diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 58906e0a..d899eadc 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Union import boto3 from botocore.client import Config @@ -9,49 +9,76 @@ from langchain_core.retrievers import BaseRetriever from typing_extensions import Annotated +FilterValue = Union[Dict[str, Any], List[Any], int, float, str, bool, None] +Filter = Dict[str, FilterValue] + + +class SearchFilter(BaseModel): + """Filter configuration for retrieval.""" + + andAll: Optional[List["SearchFilter"]] = None + orAll: Optional[List["SearchFilter"]] = None + equals: Optional[Filter] = None + greaterThan: Optional[Filter] = None + greaterThanOrEquals: Optional[Filter] = None + in_: Optional[Filter] = Field(None, alias="in") + lessThan: Optional[Filter] = None + lessThanOrEquals: Optional[Filter] = None + listContains: Optional[Filter] = None + notEquals: Optional[Filter] = None + notIn: Optional[Filter] = Field(None, alias="notIn") + startsWith: Optional[Filter] = None + stringContains: Optional[Filter] = None + + class Config: + allow_population_by_field_name = True + class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] """Configuration for vector search.""" numberOfResults: int = 4 + filter: Optional[SearchFilter] = None + overrideSearchType: Optional[Literal["HYBRID", "SEMANTIC"]] = None class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] """Configuration for retrieval.""" vectorSearchConfiguration: VectorSearchConfig + nextToken: Optional[str] = None class AmazonKnowledgeBasesRetriever(BaseRetriever): """`Amazon Bedrock Knowledge Bases` retrieval. - See https://aws.amazon.com/bedrock/knowledge-bases for more info. - - Args: - knowledge_base_id: Knowledge Base ID. 
- region_name: The aws region e.g., `us-west-2`. - Fallback to AWS_DEFAULT_REGION env variable or region specified in - ~/.aws/config. - credentials_profile_name: The name of the profile in the ~/.aws/credentials - or ~/.aws/config files, which has either access keys or role information - specified. If not specified, the default credential profile or, if on an - EC2 instance, credentials from IMDS will be used. - client: boto3 client for bedrock agent runtime. - retrieval_config: Configuration for retrieval. - - Example: - .. code-block:: python - - from langchain_community.retrievers import AmazonKnowledgeBasesRetriever - - retriever = AmazonKnowledgeBasesRetriever( - knowledge_base_id="", - retrieval_config={ - "vectorSearchConfiguration": { - "numberOfResults": 4 - } - }, - ) + See https://aws.amazon.com/bedrock/knowledge-bases for more info. + + Args: + knowledge_base_id: Knowledge Base ID. + region_name: The aws region e.g., `us-west-2`. + Fallback to AWS_DEFAULT_REGION env variable or region specified in + ~/.aws/config. + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + client: boto3 client for bedrock agent runtime. + retrieval_config: Configuration for retrieval. + + Example: + .. 
code-block:: python + + from langchain_aws.retrievers.bedrock import AmazonKnowledgeBasesRetriever + + retriever = AmazonKnowledgeBasesRetriever( + knowledge_base_id="", + retrieval_config={ + "vectorSearchConfiguration": { + "numberOfResults": 4 + } + }, + ) """ knowledge_base_id: str @@ -123,12 +150,17 @@ def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: return filtered_docs def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: response = self.client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, - retrievalConfiguration=self.retrieval_config.dict(), + retrievalConfiguration=self.retrieval_config.dict( + exclude_none=True, by_alias=True + ), ) results = response["retrievalResults"] documents = []