Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding bedrock metadata filter and search type parameters #104

Merged
merged 10 commits into from
Jul 17, 2024
99 changes: 71 additions & 28 deletions libs/aws/langchain_aws/retrievers/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import boto3
from botocore.client import Config
Expand All @@ -10,48 +10,90 @@
from typing_extensions import Annotated


# Value type shared by every leaf operator on SearchFilter: a dict with string
# keys whose values may be a dict, list, number, string, bool or None.
# Declared once so the field list below stays readable.
_FilterAttribute = Dict[str, Union[Dict, List, int, float, str, bool, None]]


class SearchFilter(BaseModel):
    """Filter configuration for retrieval.

    Every field is optional and defaults to ``None``. ``andAll`` and ``orAll``
    hold lists of nested ``SearchFilter`` objects, so filters can be composed
    recursively; all remaining fields hold a single attribute mapping.
    """

    # Combinators: each entry is itself a SearchFilter.
    andAll: Optional[List["SearchFilter"]] = None

    equals: Optional[_FilterAttribute] = None
    greaterThan: Optional[_FilterAttribute] = None
    greaterThanOrEquals: Optional[_FilterAttribute] = None
    # ``in`` is a reserved word in Python, so the attribute carries a trailing
    # underscore and is exposed under the alias "in".
    in_: Optional[_FilterAttribute] = Field(default=None, alias="in")
    lessThan: Optional[_FilterAttribute] = None
    lessThanOrEquals: Optional[_FilterAttribute] = None
    listContains: Optional[_FilterAttribute] = None
    notEquals: Optional[_FilterAttribute] = None
    # The alias here is identical to the attribute name.
    notIn: Optional[_FilterAttribute] = Field(default=None, alias="notIn")

    orAll: Optional[List["SearchFilter"]] = None

    startsWith: Optional[_FilterAttribute] = None
    stringContains: Optional[_FilterAttribute] = None

    class Config:
        # Lets aliased fields (e.g. ``in_``) be populated by attribute name as
        # well as by alias.
        allow_population_by_field_name = True


class VectorSearchConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
    """Configuration for vector search.

    The model is declared with ``extra="allow"``, so keys beyond the three
    fields listed here are not restricted by this class definition.
    """

    # How many results to request; 4 unless overridden.
    numberOfResults: int = 4
    # Optional filter; omitted (None) by default.
    filter: Optional[SearchFilter] = None
    # Can be 'HYBRID' or 'SEMANTIC'. Typed as a plain string, so the value is
    # not checked against those two options here.
    overrideSearchType: Optional[str] = None

ravediamond marked this conversation as resolved.
Show resolved Hide resolved

class RetrievalConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
    """Configuration for retrieval.

    Wraps the vector-search settings together with an optional ``nextToken``
    string. Declared with ``extra="allow"``.
    """

    # Required: there is no default, so callers must supply this section.
    vectorSearchConfiguration: VectorSearchConfig
    # Optional token string; None when not provided.
    nextToken: Optional[str] = None


class AmazonKnowledgeBasesRetriever(BaseRetriever):
"""`Amazon Bedrock Knowledge Bases` retrieval.

See https://aws.amazon.com/bedrock/knowledge-bases for more info.

Args:
knowledge_base_id: Knowledge Base ID.
region_name: The aws region e.g., `us-west-2`.
Fallback to AWS_DEFAULT_REGION env variable or region specified in
~/.aws/config.
credentials_profile_name: The name of the profile in the ~/.aws/credentials
or ~/.aws/config files, which has either access keys or role information
specified. If not specified, the default credential profile or, if on an
EC2 instance, credentials from IMDS will be used.
client: boto3 client for bedrock agent runtime.
retrieval_config: Configuration for retrieval.

Example:
.. code-block:: python

from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

retriever = AmazonKnowledgeBasesRetriever(
knowledge_base_id="<knowledge-base-id>",
retrieval_config={
"vectorSearchConfiguration": {
"numberOfResults": 4
}
},
)
See https://aws.amazon.com/bedrock/knowledge-bases for more info.

Args:
knowledge_base_id: Knowledge Base ID.
region_name: The aws region e.g., `us-west-2`.
Fallback to AWS_DEFAULT_REGION env variable or region specified in
~/.aws/config.
credentials_profile_name: The name of the profile in the ~/.aws/credentials
or ~/.aws/config files, which has either access keys or role information
specified. If not specified, the default credential profile or, if on an
EC2 instance, credentials from IMDS will be used.
client: boto3 client for bedrock agent runtime.
retrieval_config: Configuration for retrieval.

Example:
.. code-block:: python

from langchain_community.retrievers import AmazonKnowledgeBasesRetriever

retriever = AmazonKnowledgeBasesRetriever(
knowledge_base_id="<knowledge-base-id>",
retrieval_config={
"vectorSearchConfiguration": {
"numberOfResults": 4
}
},
)
"""

knowledge_base_id: str
Expand Down Expand Up @@ -128,6 +170,7 @@ def _get_relevant_documents(
response = self.client.retrieve(
retrievalQuery={"text": query.strip()},
knowledgeBaseId=self.knowledge_base_id,
nextToken=None,
retrievalConfiguration=self.retrieval_config.dict(),
ravediamond marked this conversation as resolved.
Show resolved Hide resolved
ravediamond marked this conversation as resolved.
Show resolved Hide resolved
)
results = response["retrievalResults"]
Expand Down
Loading