From cf563eeb1fb728e9a09839e53c238b8d647a19e4 Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Fri, 5 Jul 2024 13:16:45 +0200 Subject: [PATCH 1/9] feat: added bedrock metadata filter and search type parameters --- libs/aws/langchain_aws/retrievers/bedrock.py | 45 +++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 58906e0a..bdf30ceb 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import boto3 from botocore.client import Config @@ -10,16 +10,58 @@ from typing_extensions import Annotated +class SearchFilter(BaseModel): + """Filter configuration for retrieval.""" + + andAll: Optional[List["SearchFilter"]] = None + equals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = None + greaterThan: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( + None + ) + greaterThanOrEquals: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None + in_: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = Field( + None, alias="in" + ) + lessThan: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = None + lessThanOrEquals: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None + listContains: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None + notEquals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( + None + ) + notIn: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = Field( + None, alias="notIn" + ) + orAll: Optional[List["SearchFilter"]] = None + startsWith: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( + None + ) + stringContains: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None + + class Config: + allow_population_by_field_name = True + + class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] """Configuration for vector search.""" numberOfResults: int = 4 + filter: Optional[SearchFilter] = None + overrideSearchType: Optional[str] = None # Can be 'HYBRID' or 'SEMANTIC' class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] """Configuration for retrieval.""" vectorSearchConfiguration: VectorSearchConfig + nextToken: Optional[str] = None class AmazonKnowledgeBasesRetriever(BaseRetriever): @@ -128,6 +170,7 @@ def _get_relevant_documents( response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, + nextToken=None, retrievalConfiguration=self.retrieval_config.dict(), ) results = response["retrievalResults"] From bce17aa3e451c1c91c19542d515e69bc61eb2ded Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Fri, 5 Jul 2024 14:12:04 +0200 Subject: [PATCH 2/9] fix: fixed formating --- libs/aws/langchain_aws/retrievers/bedrock.py | 72 ++++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index bdf30ceb..37a81fc2 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -15,9 +15,9 @@ class SearchFilter(BaseModel): andAll: Optional[List["SearchFilter"]] = None equals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = None - greaterThan: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( - None - ) + greaterThan: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None greaterThanOrEquals: Optional[ Dict[str, Union[Dict, List, int, float, str, bool, None]] ] = None @@ -31,16 +31,16 @@ class SearchFilter(BaseModel): listContains: Optional[ Dict[str, Union[Dict, List, int, float, str, bool, None]] ] = None - notEquals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( - None - ) + notEquals: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None notIn: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = Field( None, alias="notIn" ) orAll: Optional[List["SearchFilter"]] = None - startsWith: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( - None - ) + startsWith: Optional[ + Dict[str, Union[Dict, List, int, float, str, bool, None]] + ] = None stringContains: Optional[ Dict[str, Union[Dict, List, int, float, str, bool, None]] ] = None @@ -67,33 +67,33 @@ class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] class AmazonKnowledgeBasesRetriever(BaseRetriever): """`Amazon Bedrock Knowledge Bases` retrieval. - See https://aws.amazon.com/bedrock/knowledge-bases for more info. - - Args: - knowledge_base_id: Knowledge Base ID. - region_name: The aws region e.g., `us-west-2`. - Fallback to AWS_DEFAULT_REGION env variable or region specified in - ~/.aws/config. - credentials_profile_name: The name of the profile in the ~/.aws/credentials - or ~/.aws/config files, which has either access keys or role information - specified. If not specified, the default credential profile or, if on an - EC2 instance, credentials from IMDS will be used. - client: boto3 client for bedrock agent runtime. - retrieval_config: Configuration for retrieval. - - Example: - .. code-block:: python - - from langchain_community.retrievers import AmazonKnowledgeBasesRetriever - - retriever = AmazonKnowledgeBasesRetriever( - knowledge_base_id="", - retrieval_config={ - "vectorSearchConfiguration": { - "numberOfResults": 4 - } - }, - ) + See https://aws.amazon.com/bedrock/knowledge-bases for more info. + + Args: + knowledge_base_id: Knowledge Base ID. + region_name: The aws region e.g., `us-west-2`. + Fallback to AWS_DEFAULT_REGION env variable or region specified in + ~/.aws/config. + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + client: boto3 client for bedrock agent runtime. + retrieval_config: Configuration for retrieval. + + Example: + .. code-block:: python + + from langchain_community.retrievers import AmazonKnowledgeBasesRetriever + + retriever = AmazonKnowledgeBasesRetriever( + knowledge_base_id="", + retrieval_config={ + "vectorSearchConfiguration": { + "numberOfResults": 4 + } + }, + ) """ knowledge_base_id: str From dc2dd6a108df8e924637862840bb44f98f2a204a Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Wed, 10 Jul 2024 19:49:41 +0200 Subject: [PATCH 3/9] feat: implemented next_token parameter and return value --- libs/aws/langchain_aws/retrievers/bedrock.py | 39 ++++++++++++-------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 37a81fc2..52d67c8b 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Literal, Tuple import boto3 from botocore.client import Config @@ -15,9 +15,9 @@ class SearchFilter(BaseModel): andAll: Optional[List["SearchFilter"]] = None equals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = None - greaterThan: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None + greaterThan: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( + None + ) greaterThanOrEquals: Optional[ Dict[str, Union[Dict, List, int, float, str, bool, None]] ] = None @@ -31,16 +31,16 @@ class SearchFilter(BaseModel): listContains: Optional[ Dict[str, Union[Dict, List, int, float, str, bool, None]] ] = None - notEquals: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None + notEquals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( + None + ) notIn: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = Field( None, alias="notIn" ) orAll: Optional[List["SearchFilter"]] = None - startsWith: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None + startsWith: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( + None + ) stringContains: Optional[ Dict[str, Union[Dict, List, int, float, str, bool, None]] ] = None @@ -54,7 +54,7 @@ class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] numberOfResults: int = 4 filter: Optional[SearchFilter] = None - overrideSearchType: Optional[str] = None # Can be 'HYBRID' or 'SEMANTIC' + overrideSearchType: Optional[Literal["HYBRID", "SEMANTIC"]] = None class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] @@ -165,14 +165,20 @@ def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: return filtered_docs def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun - ) -> List[Document]: + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + return_token: bool = False, + next_token: Optional[str] = None + ) -> Union[List[Document], Tuple[List[Document], Optional[str]]]: response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, - nextToken=None, + nextToken=next_token or self.retrieval_config.nextToken, retrievalConfiguration=self.retrieval_config.dict(), ) + new_next_token = response.get("nextToken", None) results = response["retrievalResults"] documents = [] for result in results: @@ -189,4 +195,7 @@ def _get_relevant_documents( ) ) - return self._filter_by_score_confidence(docs=documents) + filtered_documents = self._filter_by_score_confidence(docs=documents) + if return_token: + return (filtered_documents, new_next_token) + return filtered_documents From 0e7ede7363a875f00dcf19848a009fb23c4dce8d Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Wed, 10 Jul 2024 20:22:02 +0200 Subject: [PATCH 4/9] fix: improved get_releveant_documents function --- libs/aws/langchain_aws/retrievers/bedrock.py | 91 ++++++++++++-------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 52d67c8b..293c0b1f 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union, Literal, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple, Union import boto3 from botocore.client import Config @@ -9,41 +9,32 @@ from langchain_core.retrievers import BaseRetriever from typing_extensions import Annotated +FilterValue = Union[Dict[str, Any], List[Any], int, float, str, bool, None] + + +class Criterion(BaseModel): + """A model to encapsulate a filter criterion with a key and its value.""" + + key: str + value: FilterValue + class SearchFilter(BaseModel): """Filter configuration for retrieval.""" andAll: Optional[List["SearchFilter"]] = None - equals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = None - greaterThan: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( - None - ) - greaterThanOrEquals: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None - in_: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = Field( - None, alias="in" - ) - lessThan: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = None - lessThanOrEquals: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None - listContains: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None - notEquals: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( - None - ) - notIn: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = Field( - None, alias="notIn" - ) orAll: Optional[List["SearchFilter"]] = None - startsWith: Optional[Dict[str, Union[Dict, List, int, float, str, bool, None]]] = ( - None - ) - stringContains: Optional[ - Dict[str, Union[Dict, List, int, float, str, bool, None]] - ] = None + equals: Optional[Criterion] = None + greaterThan: Optional[Criterion] = None + greaterThanOrEquals: Optional[Criterion] = None + in_: Optional[Criterion] = Field(None, alias="in") + lessThan: Optional[Criterion] = None + lessThanOrEquals: Optional[Criterion] = None + listContains: Optional[Criterion] = None + notEquals: Optional[Criterion] = None + notIn: Optional[Criterion] = Field(None, alias="notIn") + startsWith: Optional[Criterion] = None + stringContains: Optional[Criterion] = None class Config: allow_population_by_field_name = True @@ -169,9 +160,39 @@ def _get_relevant_documents( query: str, *, run_manager: CallbackManagerForRetrieverRun, - return_token: bool = False, - next_token: Optional[str] = None - ) -> Union[List[Document], Tuple[List[Document], Optional[str]]]: + next_token: Optional[str] = None, + ) -> List[Document]: + response = self.client.retrieve( + retrievalQuery={"text": query.strip()}, + knowledgeBaseId=self.knowledge_base_id, + nextToken=next_token, + retrievalConfiguration=self.retrieval_config.dict(), + ) + results = response["retrievalResults"] + documents = [] + for result in results: + content = result["content"]["text"] + result.pop("content") + if "score" not in result: + result["score"] = 0 + if "metadata" in result: + result["source_metadata"] = result.pop("metadata") + documents.append( + Document( + page_content=content, + metadata=result, + ) + ) + + return self._filter_by_score_confidence(docs=documents) + + def _get_relevant_documents_with_token( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + next_token: Optional[str] = None, + ) -> Tuple[List[Document], Optional[str]]: response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, @@ -196,6 +217,4 @@ def _get_relevant_documents( ) filtered_documents = self._filter_by_score_confidence(docs=documents) - if return_token: - return (filtered_documents, new_next_token) - return filtered_documents + return (filtered_documents, new_next_token) From c96116bd1435ea1c6a942eaf4a70b7beaf007b5c Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Fri, 12 Jul 2024 16:23:28 +0200 Subject: [PATCH 5/9] fix: removed argument nextToken --- libs/aws/langchain_aws/retrievers/bedrock.py | 36 +------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 293c0b1f..bab024f9 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -160,12 +160,11 @@ def _get_relevant_documents( query: str, *, run_manager: CallbackManagerForRetrieverRun, - next_token: Optional[str] = None, ) -> List[Document]: response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, - nextToken=next_token, + nextToken=None, retrievalConfiguration=self.retrieval_config.dict(), ) results = response["retrievalResults"] @@ -185,36 +184,3 @@ def _get_relevant_documents( ) return self._filter_by_score_confidence(docs=documents) - - def _get_relevant_documents_with_token( - self, - query: str, - *, - run_manager: CallbackManagerForRetrieverRun, - next_token: Optional[str] = None, - ) -> Tuple[List[Document], Optional[str]]: - response = self.client.retrieve( - retrievalQuery={"text": query.strip()}, - knowledgeBaseId=self.knowledge_base_id, - nextToken=next_token or self.retrieval_config.nextToken, - retrievalConfiguration=self.retrieval_config.dict(), - ) - new_next_token = response.get("nextToken", None) - results = response["retrievalResults"] - documents = [] - for result in results: - content = result["content"]["text"] - result.pop("content") - if "score" not in result: - result["score"] = 0 - if "metadata" in result: - result["source_metadata"] = result.pop("metadata") - documents.append( - Document( - page_content=content, - metadata=result, - ) - ) - - filtered_documents = self._filter_by_score_confidence(docs=documents) - return (filtered_documents, new_next_token) From 7ad2a4fa2f88856f025d88e3614d7dbe8b41818c Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Fri, 12 Jul 2024 16:24:39 +0200 Subject: [PATCH 6/9] fix: removed unused import --- libs/aws/langchain_aws/retrievers/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index bab024f9..b5f0b63c 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Union import boto3 from botocore.client import Config From a508aae026d1cc9d41f6d6cbb928c9996ba6d4b3 Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Fri, 12 Jul 2024 16:41:01 +0200 Subject: [PATCH 7/9] fix: simplified filter implementation --- libs/aws/langchain_aws/retrievers/bedrock.py | 30 ++++++++------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index b5f0b63c..c96307f9 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -10,13 +10,7 @@ from typing_extensions import Annotated FilterValue = Union[Dict[str, Any], List[Any], int, float, str, bool, None] - - -class Criterion(BaseModel): - """A model to encapsulate a filter criterion with a key and its value.""" - - key: str - value: FilterValue +Filter = Dict[str, FilterValue] class SearchFilter(BaseModel): @@ -24,17 +18,17 @@ class SearchFilter(BaseModel): andAll: Optional[List["SearchFilter"]] = None orAll: Optional[List["SearchFilter"]] = None - equals: Optional[Criterion] = None - greaterThan: Optional[Criterion] = None - greaterThanOrEquals: Optional[Criterion] = None - in_: Optional[Criterion] = Field(None, alias="in") - lessThan: Optional[Criterion] = None - lessThanOrEquals: Optional[Criterion] = None - listContains: Optional[Criterion] = None - notEquals: Optional[Criterion] = None - notIn: Optional[Criterion] = Field(None, alias="notIn") - startsWith: Optional[Criterion] = None - stringContains: Optional[Criterion] = None + equals: Optional[Filter] = None + greaterThan: Optional[Filter] = None + greaterThanOrEquals: Optional[Filter] = None + in_: Optional[Filter] = Field(None, alias="in") + lessThan: Optional[Filter] = None + lessThanOrEquals: Optional[Filter] = None + listContains: Optional[Filter] = None + notEquals: Optional[Filter] = None + notIn: Optional[Filter] = Field(None, alias="notIn") + startsWith: Optional[Filter] = None + stringContains: Optional[Filter] = None class Config: allow_population_by_field_name = True From 674a1162a58b3d547620cb72d737853d95b41e4d Mon Sep 17 00:00:00 2001 From: "ravindu.somawansa" Date: Wed, 17 Jul 2024 16:47:26 +0200 Subject: [PATCH 8/9] fix: fixed pr review --- libs/aws/langchain_aws/retrievers/bedrock.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index c96307f9..5a4f8bc3 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -158,7 +158,6 @@ def _get_relevant_documents( response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, - nextToken=None, retrievalConfiguration=self.retrieval_config.dict(), ) results = response["retrievalResults"] From e3a2a27ad2a018bb42d49dcd658db63e4b8d8101 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Wed, 17 Jul 2024 15:33:31 -0700 Subject: [PATCH 9/9] Small fix to exclude empty filter. --- libs/aws/langchain_aws/retrievers/bedrock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 5a4f8bc3..d899eadc 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -158,7 +158,7 @@ def _get_relevant_documents( response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, - retrievalConfiguration=self.retrieval_config.dict(), + retrievalConfiguration=self.retrieval_config.dict(exclude_none=True), ) results = response["retrievalResults"] documents = []