diff --git a/libs/aws/Makefile b/libs/aws/Makefile index 6366a9d5..98577244 100644 --- a/libs/aws/Makefile +++ b/libs/aws/Makefile @@ -18,7 +18,7 @@ test tests integration_test integration_tests: PYTHON_FILES=. MYPY_CACHE=.mypy_cache lint format: PYTHON_FILES=. -lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/aws --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$') +lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/aws --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') lint_package: PYTHON_FILES=langchain_aws lint_tests: PYTHON_FILES=tests lint_tests: MYPY_CACHE=.mypy_cache_test diff --git a/libs/aws/langchain_aws/chat_models/bedrock.py b/libs/aws/langchain_aws/chat_models/bedrock.py index 66f40fba..76ceebfa 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock.py +++ b/libs/aws/langchain_aws/chat_models/bedrock.py @@ -36,7 +36,10 @@ from langchain_core.tools import BaseTool from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message -from langchain_aws.llms.bedrock import BedrockBase +from langchain_aws.llms.bedrock import ( + BedrockBase, + _combine_generation_info_for_llm_result, +) from langchain_aws.utils import ( get_num_tokens_anthropic, get_token_ids_anthropic, @@ -379,7 +382,13 @@ def _stream( **kwargs, ): delta = chunk.text - yield ChatGenerationChunk(message=AIMessageChunk(content=delta)) + yield ChatGenerationChunk( + message=AIMessageChunk( + content=delta, response_metadata=chunk.generation_info + ) + if chunk.generation_info is not None + else AIMessageChunk(content=delta) + ) def _generate( self, @@ -389,11 +398,18 @@ def _generate( **kwargs: Any, ) -> ChatResult: completion = "" - llm_output: Dict[str, Any] = {"model_id": self.model_id} - usage_info: Dict[str, Any] = {} + llm_output: Dict[str, Any] = {} + provider_stop_reason_code = self.provider_stop_reason_key_map.get( + self._get_provider(), "stop_reason" + ) if self.streaming: + response_metadata: List[Dict[str, Any]] = [] for chunk in self._stream(messages, stop, run_manager, **kwargs): completion += chunk.text + response_metadata.append(chunk.message.response_metadata) + llm_output = _combine_generation_info_for_llm_result( + response_metadata, provider_stop_reason_code + ) else: provider = self._get_provider() prompt, system, formatted_messages = None, None, None @@ -416,7 +432,7 @@ def _generate( if stop: params["stop_sequences"] = stop - completion, usage_info = self._prepare_input_and_invoke( + completion, llm_output = self._prepare_input_and_invoke( prompt=prompt, stop=stop, run_manager=run_manager, @@ -425,14 +441,11 @@ def _generate( **params, ) - llm_output["usage"] = usage_info - + llm_output["model_id"] = self.model_id return ChatResult( generations=[ ChatGeneration( - message=AIMessage( - content=completion, additional_kwargs={"usage": usage_info} - ) + message=AIMessage(content=completion, additional_kwargs=llm_output) ) ], llm_output=llm_output, @@ -443,7 +456,7 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: final_output = {} for output in llm_outputs: output = output or {} - usage = output.pop("usage", {}) + usage = output.get("usage", {}) for token_type, token_count in usage.items(): final_usage[token_type] += token_count final_output.update(output) diff --git a/libs/aws/langchain_aws/llms/__init__.py b/libs/aws/langchain_aws/llms/__init__.py index 3255494a..5a4fa68b 100644 --- a/libs/aws/langchain_aws/llms/__init__.py +++ b/libs/aws/langchain_aws/llms/__init__.py @@ 
-3,6 +3,7 @@ Bedrock, BedrockBase, BedrockLLM, + LLMInputOutputAdapter, ) from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint @@ -11,5 +12,6 @@ "Bedrock", "BedrockBase", "BedrockLLM", + "LLMInputOutputAdapter", "SagemakerEndpoint", ] diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 1f900024..c20e3365 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -12,6 +12,7 @@ Mapping, Optional, Tuple, + Union, ) from langchain_core._api.deprecation import deprecated @@ -20,7 +21,7 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models import LLM, BaseLanguageModel -from langchain_core.outputs import GenerationChunk +from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Extra, Field, root_validator from langchain_core.utils import get_from_dict_or_env @@ -80,18 +81,98 @@ def _human_assistant_format(input_text: str) -> str: def _stream_response_to_generation_chunk( - stream_response: Dict[str, Any], -) -> GenerationChunk: + stream_response: Dict[str, Any], provider: str, output_key: str, messages_api: bool +) -> Union[GenerationChunk, None]: """Convert a stream response to a generation chunk.""" - if not stream_response["delta"]: - return GenerationChunk(text="") - return GenerationChunk( - text=stream_response["delta"]["text"], - generation_info=dict( - finish_reason=stream_response.get("stop_reason", None), - ), + if messages_api: + msg_type = stream_response.get("type") + if msg_type == "message_start": + usage_info = stream_response.get("message", {}).get("usage", None) + usage_info = _nest_usage_info_token_counts(usage_info) + generation_info = {"usage": usage_info} + return GenerationChunk(text="", generation_info=generation_info) + elif msg_type == "content_block_delta": + if not stream_response["delta"]: + return GenerationChunk(text="") + return GenerationChunk( + text=stream_response["delta"]["text"], + generation_info=dict( + stop_reason=stream_response.get("stop_reason", None), + ), + ) + elif msg_type == "message_delta": + usage_info = stream_response.get("usage", None) + usage_info = _nest_usage_info_token_counts(usage_info) + stop_reason = stream_response.get("delta", {}).get("stop_reason") + generation_info = {"stop_reason": stop_reason, "usage": usage_info} + return GenerationChunk(text="", generation_info=generation_info) + else: + return None + else: + # chunk obj format varies with provider + generation_info = {k: v for k, v in stream_response.items() if k != output_key} + return GenerationChunk( + text=( + stream_response[output_key] + if provider != "mistral" + else stream_response[output_key][0]["text"] + ), + generation_info=generation_info, + ) + + +def _nest_usage_info_token_counts(usage_info: dict) -> dict: + """ + Sticking usage info for token counts into lists to + deal with langchain_core.utils.merge_dicts incompatibility + in which integers must be equal to be merged + as seen here: https://github.com/langchain-ai/langchain-aws/pull/20#issuecomment-2118166376 + """ + if "input_tokens" in usage_info: + usage_info["input_tokens"] = [usage_info["input_tokens"]] + if "output_tokens" in usage_info: + usage_info["output_tokens"] = [usage_info["output_tokens"]] + return usage_info + + +def _combine_generation_info_for_llm_result( + chunks_generation_info: List[Dict[str, Any]], provider_stop_code: str +) -> Dict[str, Any]: + """ + Returns usage and stop reason information with the intent to pack 
into an LLMResult + Takes a list of generation_info from GenerationChunks + If the messages api is being used, + the generation_info from some of these chunks should contain "usage" keys + if not, the token counts should be found within "amazon-bedrock-invocationMetrics" + """ + total_usage_info = {"prompt_tokens": 0, "completion_tokens": 0} + stop_reason = "" + for generation_info in chunks_generation_info: + if "usage" in generation_info: + usage_info = generation_info["usage"] + if "input_tokens" in usage_info: + total_usage_info["prompt_tokens"] += sum(usage_info["input_tokens"]) + if "output_tokens" in usage_info: + total_usage_info["completion_tokens"] += sum( + usage_info["output_tokens"] + ) + if "amazon-bedrock-invocationMetrics" in generation_info: + usage_info = generation_info["amazon-bedrock-invocationMetrics"] + if "inputTokenCount" in usage_info: + total_usage_info["prompt_tokens"] += usage_info["inputTokenCount"] + if "outputTokenCount" in usage_info: + total_usage_info["completion_tokens"] += usage_info["outputTokenCount"] + + if provider_stop_code is not None and provider_stop_code in generation_info: + # uses the last stop reason + stop_reason = generation_info[provider_stop_code] + + total_usage_info["total_tokens"] = ( + total_usage_info["prompt_tokens"] + total_usage_info["completion_tokens"] ) + return {"usage": total_usage_info, "stop_reason": stop_reason} + class LLMInputOutputAdapter: """Adapter class to prepare the inputs from Langchain to a format @@ -176,6 +257,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict: "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, + "stop_reason": response_body.get("stop_reason"), } @classmethod @@ -219,39 +301,27 @@ def prepare_output_stream( ): return - elif messages_api and (chunk_obj.get("type") == "content_block_stop"): + elif messages_api and (chunk_obj.get("type") == "message_stop"): return - if messages_api and chunk_obj.get("type") in ( - "message_start", - "content_block_start", - "content_block_delta", - ): - if chunk_obj.get("type") == "content_block_delta": - chk = _stream_response_to_generation_chunk(chunk_obj) - yield chk - else: - continue + generation_chunk = _stream_response_to_generation_chunk( + chunk_obj, + provider=provider, + output_key=output_key, + messages_api=messages_api, + ) + if generation_chunk: + yield generation_chunk else: - # chunk obj format varies with provider - yield GenerationChunk( - text=( - chunk_obj[output_key] - if provider != "mistral" - else chunk_obj[output_key][0]["text"] - ), - generation_info={ - GUARDRAILS_BODY_KEY: ( - chunk_obj.get(GUARDRAILS_BODY_KEY) - if GUARDRAILS_BODY_KEY in chunk_obj - else None - ), - }, - ) + continue @classmethod async def aprepare_output_stream( - cls, provider: str, response: Any, stop: Optional[List[str]] = None + cls, + provider: str, + response: Any, + stop: Optional[List[str]] = None, + messages_api: bool = False, ) -> AsyncIterator[GenerationChunk]: stream = response.get("body") @@ -283,13 +353,16 @@ async def aprepare_output_stream( ): return - yield GenerationChunk( - text=( - chunk_obj[output_key] - if provider != "mistral" - else chunk_obj[output_key][0]["text"] - ) + generation_chunk = _stream_response_to_generation_chunk( + chunk_obj, + provider=provider, + output_key=output_key, + messages_api=messages_api, ) + if generation_chunk: + yield generation_chunk + else: + continue class BedrockBase(BaseLanguageModel, ABC): @@ -342,6 +415,14 @@ class BedrockBase(BaseLanguageModel, ABC): 
"mistral": "stop_sequences", } + provider_stop_reason_key_map: Mapping[str, str] = { + "anthropic": "stop_reason", + "amazon": "completionReason", + "ai21": "finishReason", + "cohere": "finish_reason", + "mistral": "stop_reason", + } + guardrails: Optional[Mapping[str, Any]] = { "trace": None, "guardrailIdentifier": None, @@ -540,7 +621,7 @@ def _prepare_input_and_invoke( try: response = self.client.invoke_model(**request_options) - text, body, usage_info = LLMInputOutputAdapter.prepare_output( + text, body, usage_info, stop_reason = LLMInputOutputAdapter.prepare_output( provider, response ).values() @@ -550,12 +631,14 @@ def _prepare_input_and_invoke( if stop is not None: text = enforce_stop_tokens(text, stop) + llm_output = {"usage": usage_info, "stop_reason": stop_reason} + # Verify and raise a callback error if any intervention occurs or a signal is # sent from a Bedrock service, # such as when guardrails are triggered. services_trace = self._get_bedrock_services_signal(body) # type: ignore[arg-type] - if services_trace.get("signal") and run_manager is not None: + if run_manager is not None and services_trace.get("signal"): run_manager.on_llm_error( Exception( f"Error raised by bedrock service: {services_trace.get('reason')}" @@ -563,7 +646,7 @@ def _prepare_input_and_invoke( **services_trace, ) - return text, usage_info + return text, llm_output def _get_bedrock_services_signal(self, body: dict) -> dict: """ @@ -666,6 +749,8 @@ def _prepare_input_and_invoke_stream( async def _aprepare_input_and_invoke_stream( self, prompt: str, + system: Optional[str] = None, + messages: Optional[List[Dict]] = None, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, @@ -686,7 +771,11 @@ async def _aprepare_input_and_invoke_stream( params = {**_model_kwargs, **kwargs} input_body = LLMInputOutputAdapter.prepare_input( - provider=provider, prompt=prompt, model_kwargs=params + provider=provider, + prompt=prompt, + system=system, + messages=messages, + model_kwargs=params, ) body = json.dumps(input_body) @@ -701,7 +790,7 @@ async def _aprepare_input_and_invoke_stream( ) async for chunk in LLMInputOutputAdapter.aprepare_output_stream( - provider, response, stop + provider, response, stop, True if messages else False ): yield chunk if run_manager is not None and asyncio.iscoroutinefunction( @@ -829,17 +918,47 @@ def _call( response = llm("Tell me a joke.") """ + provider = self._get_provider() + provider_stop_reason_code = self.provider_stop_reason_key_map.get( + provider, "stop_reason" + ) + if self.streaming: + all_chunks: List[GenerationChunk] = [] completion = "" for chunk in self._stream( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ): completion += chunk.text + all_chunks.append(chunk) + + if run_manager is not None: + chunks_generation_info = [ + chunk.generation_info + for chunk in all_chunks + if chunk.generation_info is not None + ] + llm_output = _combine_generation_info_for_llm_result( + chunks_generation_info, provider_stop_code=provider_stop_reason_code + ) + all_generations = [ + Generation(text=chunk.text, generation_info=chunk.generation_info) + for chunk in all_chunks + ] + run_manager.on_llm_end( + LLMResult(generations=[all_generations], llm_output=llm_output) + ) + return completion - text, _ = self._prepare_input_and_invoke( + text, llm_output = self._prepare_input_and_invoke( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ) + if run_manager is not None: + run_manager.on_llm_end( + 
LLMResult(generations=[[Generation(text=text)]], llm_output=llm_output) + ) + return text async def _astream( @@ -893,13 +1012,36 @@ async def _acall( if not self.streaming: raise ValueError("Streaming must be set to True for async operations. ") + provider = self._get_provider() + provider_stop_reason_code = self.provider_stop_reason_key_map.get( + provider, "stop_reason" + ) + chunks = [ - chunk.text + chunk async for chunk in self._astream( prompt=prompt, stop=stop, run_manager=run_manager, **kwargs ) ] - return "".join(chunks) + + if run_manager is not None: + chunks_generation_info = [ + chunk.generation_info + for chunk in chunks + if chunk.generation_info is not None + ] + llm_output = _combine_generation_info_for_llm_result( + chunks_generation_info, provider_stop_code=provider_stop_reason_code + ) + generations = [ + Generation(text=chunk.text, generation_info=chunk.generation_info) + for chunk in chunks + ] + await run_manager.on_llm_end( + LLMResult(generations=[generations], llm_output=llm_output) + ) + + return "".join([chunk.text for chunk in chunks]) def get_num_tokens(self, text: str) -> int: if self._model_is_anthropic: diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index f313f7fd..58906e0a 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -5,8 +5,9 @@ from botocore.exceptions import UnknownServiceError from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_core.documents import Document -from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.retrievers import BaseRetriever +from typing_extensions import Annotated class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] @@ -59,6 +60,7 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever): endpoint_url: Optional[str] = None client: Any retrieval_config: RetrievalConfig + min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)] @root_validator(pre=True) def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: @@ -103,6 +105,23 @@ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]: "profile name are valid." ) from e + def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: + """ + Filter out the records that have a score confidence + less than the required threshold. + """ + if not self.min_score_confidence: + return docs + filtered_docs = [ + item + for item in docs + if ( + item.metadata.get("score") is not None + and item.metadata.get("score", 0.0) >= self.min_score_confidence + ) + ] + return filtered_docs + def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: @@ -127,4 +146,4 @@ def _get_relevant_documents( ) ) - return documents + return self._filter_by_score_confidence(docs=documents) diff --git a/libs/aws/langchain_aws/retrievers/kendra.py b/libs/aws/langchain_aws/retrievers/kendra.py index b4480cae..5e7b5fe1 100644 --- a/libs/aws/langchain_aws/retrievers/kendra.py +++ b/libs/aws/langchain_aws/retrievers/kendra.py @@ -444,7 +444,7 @@ def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]: def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]: """ Filter out the records that have a score confidence - greater than the required threshold. + less than the required threshold. 
""" if not self.min_score_confidence: return docs diff --git a/libs/aws/pyproject.toml b/libs/aws/pyproject.toml index e5c20e5f..ec43119f 100644 --- a/libs/aws/pyproject.toml +++ b/libs/aws/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-aws" -version = "0.1.4" +version = "0.1.6" description = "An integration package connecting AWS and LangChain" authors = [] readme = "README.md" diff --git a/libs/aws/tests/callbacks.py b/libs/aws/tests/callbacks.py index 66b54256..3a3902a0 100644 --- a/libs/aws/tests/callbacks.py +++ b/libs/aws/tests/callbacks.py @@ -5,6 +5,7 @@ from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage +from langchain_core.outputs import LLMResult from langchain_core.pydantic_v1 import BaseModel @@ -271,6 +272,29 @@ def on_chat_model_start( self.on_chat_model_start_common() +class FakeCallbackHandlerWithTokenCounts(FakeCallbackHandler): + input_token_count: int = 0 + output_token_count: int = 0 + stop_reason: Union[str, None] = None + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Union[UUID, None] = None, + **kwargs: Any, + ) -> Any: + if response.llm_output is not None: + self.input_token_count += response.llm_output.get("usage", {}).get( + "prompt_tokens", None + ) + self.output_token_count += response.llm_output.get("usage", {}).get( + "completion_tokens", None + ) + self.stop_reason = response.llm_output.get("stop_reason", None) + + class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): """Fake async callback handler for testing.""" diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py index fcb3f3fa..18d0f79f 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock.py @@ -12,9 +12,10 @@ ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.runnables import RunnableConfig from langchain_aws.chat_models.bedrock import ChatBedrock -from tests.callbacks import FakeCallbackHandler +from tests.callbacks import FakeCallbackHandler, FakeCallbackHandlerWithTokenCounts @pytest.fixture @@ -111,6 +112,7 @@ def on_llm_end( chat = ChatBedrock( # type: ignore[call-arg] model_id="anthropic.claude-v2", callbacks=[callback], + model_kwargs={"temperature": 0}, ) list(chat.stream("hi")) generation = callback.saved_things["generation"] @@ -187,7 +189,7 @@ class GetWeather(BaseModel): llm_with_tools = chat.bind_tools([GetWeather]) messages = [ - SystemMessage(content="anwser only in french"), + SystemMessage(content="answer only in french"), HumanMessage(content="what is the weather like in San Francisco"), ] @@ -196,6 +198,27 @@ class GetWeather(BaseModel): assert isinstance(response.content, str) +@pytest.mark.scheduled +@pytest.mark.parametrize("streaming", [True, False]) +def test_chat_bedrock_token_callbacks(streaming: bool) -> None: + """ + Test that streaming correctly invokes on_llm_end + and stores token counts and stop reason. 
+ """ + callback_handler = FakeCallbackHandlerWithTokenCounts() + chat = ChatBedrock( # type: ignore[call-arg] + model_id="anthropic.claude-v2", + streaming=streaming, + verbose=True, + ) + message = HumanMessage(content="Hello") + response = chat.invoke([message], RunnableConfig(callbacks=[callback_handler])) + assert callback_handler.input_token_count > 0 + assert callback_handler.output_token_count > 0 + assert callback_handler.stop_reason is not None + assert isinstance(response, BaseMessage) + + @pytest.mark.scheduled def test_function_call_invoke_without_system(chat: ChatBedrock) -> None: class GetWeather(BaseModel): @@ -218,7 +241,7 @@ class GetWeather(BaseModel): llm_with_tools = chat.bind_tools([GetWeather]) messages = [ - SystemMessage(content="anwser only in french"), + SystemMessage(content="answer only in french"), HumanMessage(content="what is the weather like in San Francisco"), ] diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py index ae48ffef..54eb8bf3 100644 --- a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py +++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py @@ -17,6 +17,7 @@ def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever: knowledge_base_id="test-knowledge-base", client=mock_client, retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, # type: ignore[arg-type] + min_score_confidence=0.0, ) @@ -78,3 +79,44 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore knowledgeBaseId="test-knowledge-base", retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}}, ) + + +def test_get_relevant_documents_with_score(retriever, mock_client) -> None: # type: ignore[no-untyped-def] + response = { + "retrievalResults": [ + { + "content": {"text": "This is the first result."}, + "location": "location1", + "score": 0.9, + }, + { + "content": {"text": "This is the second result."}, + "location": "location2", + "score": 0.8, + }, + {"content": {"text": "This is the third result."}, "location": "location3"}, + { + "content": {"text": "This is the fourth result."}, + "metadata": {"key1": "value1", "key2": "value2"}, + }, + ] + } + mock_client.retrieve.return_value = response + + query = "test query" + + expected_documents = [ + Document( + page_content="This is the first result.", + metadata={"location": "location1", "score": 0.9}, + ), + Document( + page_content="This is the second result.", + metadata={"location": "location2", "score": 0.8}, + ), + ] + + retriever.min_score_confidence = 0.80 + documents = retriever.invoke(query) + + assert documents == expected_documents diff --git a/libs/aws/tests/unit_tests/llms/test_bedrock.py b/libs/aws/tests/unit_tests/llms/test_bedrock.py index 3a7a0d41..7693cb19 100644 --- a/libs/aws/tests/unit_tests/llms/test_bedrock.py +++ b/libs/aws/tests/unit_tests/llms/test_bedrock.py @@ -1,3 +1,5 @@ +# type:ignore + import json from typing import AsyncGenerator, Dict from unittest.mock import MagicMock, patch @@ -7,6 +9,7 @@ from langchain_aws import BedrockLLM from langchain_aws.llms.bedrock import ( ALTERNATION_ERROR, + LLMInputOutputAdapter, _human_assistant_format, ) @@ -306,3 +309,141 @@ async def test_bedrock_async_streaming_call() -> None: assert chunks[0] == "nice" assert chunks[1] == " to meet" assert chunks[2] == " you" + + +@pytest.fixture +def 
mistral_response(): + body = MagicMock() + body.read.return_value = json.dumps( + {"outputs": [{"text": "This is the Mistral output text."}]} + ).encode() + response = dict( + body=body, + ResponseMetadata={ + "HTTPHeaders": { + "x-amzn-bedrock-input-token-count": "18", + "x-amzn-bedrock-output-token-count": "28", + } + }, + ) + + return response + + +@pytest.fixture +def cohere_response(): + body = MagicMock() + body.read.return_value = json.dumps( + {"generations": [{"text": "This is the Cohere output text."}]} + ).encode() + response = dict( + body=body, + ResponseMetadata={ + "HTTPHeaders": { + "x-amzn-bedrock-input-token-count": "12", + "x-amzn-bedrock-output-token-count": "22", + } + }, + ) + return response + + +@pytest.fixture +def anthropic_response(): + body = MagicMock() + body.read.return_value = json.dumps( + {"completion": "This is the output text."} + ).encode() + response = dict( + body=body, + ResponseMetadata={ + "HTTPHeaders": { + "x-amzn-bedrock-input-token-count": "10", + "x-amzn-bedrock-output-token-count": "20", + } + }, + ) + return response + + +@pytest.fixture +def ai21_response(): + body = MagicMock() + body.read.return_value = json.dumps( + {"completions": [{"data": {"text": "This is the AI21 output text."}}]} + ).encode() + response = dict( + body=body, + ResponseMetadata={ + "HTTPHeaders": { + "x-amzn-bedrock-input-token-count": "15", + "x-amzn-bedrock-output-token-count": "25", + } + }, + ) + return response + + +@pytest.fixture +def response_with_stop_reason(): + body = MagicMock() + body.read.return_value = json.dumps( + {"completion": "This is the output text.", "stop_reason": "length"} + ).encode() + response = dict( + body=body, + ResponseMetadata={ + "HTTPHeaders": { + "x-amzn-bedrock-input-token-count": "10", + "x-amzn-bedrock-output-token-count": "20", + } + }, + ) + return response + + +def test_prepare_output_for_mistral(mistral_response): + result = LLMInputOutputAdapter.prepare_output("mistral", mistral_response) + assert result["text"] == "This is the Mistral output text." + assert result["usage"]["prompt_tokens"] == 18 + assert result["usage"]["completion_tokens"] == 28 + assert result["usage"]["total_tokens"] == 46 + assert result["stop_reason"] is None + + +def test_prepare_output_for_cohere(cohere_response): + result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response) + assert result["text"] == "This is the Cohere output text." + assert result["usage"]["prompt_tokens"] == 12 + assert result["usage"]["completion_tokens"] == 22 + assert result["usage"]["total_tokens"] == 34 + assert result["stop_reason"] is None + + +def test_prepare_output_with_stop_reason(response_with_stop_reason): + result = LLMInputOutputAdapter.prepare_output( + "anthropic", response_with_stop_reason + ) + assert result["text"] == "This is the output text." + assert result["usage"]["prompt_tokens"] == 10 + assert result["usage"]["completion_tokens"] == 20 + assert result["usage"]["total_tokens"] == 30 + assert result["stop_reason"] == "length" + + +def test_prepare_output_for_anthropic(anthropic_response): + result = LLMInputOutputAdapter.prepare_output("anthropic", anthropic_response) + assert result["text"] == "This is the output text." 
+ assert result["usage"]["prompt_tokens"] == 10 + assert result["usage"]["completion_tokens"] == 20 + assert result["usage"]["total_tokens"] == 30 + assert result["stop_reason"] is None + + +def test_prepare_output_for_ai21(ai21_response): + result = LLMInputOutputAdapter.prepare_output("ai21", ai21_response) + assert result["text"] == "This is the AI21 output text." + assert result["usage"]["prompt_tokens"] == 15 + assert result["usage"]["completion_tokens"] == 25 + assert result["usage"]["total_tokens"] == 40 + assert result["stop_reason"] is None diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py index 007ad139..b243db10 100644 --- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py +++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py @@ -55,3 +55,31 @@ def test_retriever_invoke(amazon_retriever, mock_client): } assert documents[2].page_content == "result3" assert documents[2].metadata == {"score": 0} + + +def test_retriever_invoke_with_score(amazon_retriever, mock_client): + query = "test query" + mock_client.retrieve.return_value = { + "retrievalResults": [ + {"content": {"text": "result1"}, "metadata": {"key": "value1"}}, + { + "content": {"text": "result2"}, + "metadata": {"key": "value2"}, + "score": 1, + "location": "testLocation", + }, + {"content": {"text": "result3"}}, + ] + } + + amazon_retriever.min_score_confidence = 0.6 + documents = amazon_retriever.invoke(query, run_manager=None) + + assert len(documents) == 1 + assert isinstance(documents[0], Document) + assert documents[0].page_content == "result2" + assert documents[0].metadata == { + "score": 1, + "source_metadata": {"key": "value2"}, + "location": "testLocation", + }
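The central new helper above is _combine_generation_info_for_llm_result, which folds the per-chunk generation_info dicts collected during streaming into the single llm_output attached to the LLMResult. A minimal sketch of its behavior, using hand-written chunk metadata in the shape the Anthropic messages-API path produces (the dicts below are illustrative, not captured from a real stream):

    from langchain_aws.llms.bedrock import _combine_generation_info_for_llm_result

    # Token counts arrive wrapped in lists (see _nest_usage_info_token_counts) so
    # that langchain_core's merge_dicts can combine chunks whose counts differ.
    chunk_infos = [
        {"usage": {"input_tokens": [10]}},                             # message_start
        {"stop_reason": None},                                         # content_block_delta
        {"usage": {"output_tokens": [7]}, "stop_reason": "end_turn"},  # message_delta
    ]

    # "stop_reason" is the provider_stop_reason_key_map entry for Anthropic.
    llm_output = _combine_generation_info_for_llm_result(
        chunk_infos, provider_stop_code="stop_reason"
    )
    # {"usage": {"prompt_tokens": 10, "completion_tokens": 7, "total_tokens": 17},
    #  "stop_reason": "end_turn"}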
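With usage and stop_reason now flowing through llm_output for both streaming and non-streaming calls, a callback handler can pick up token counts in on_llm_end, mirroring FakeCallbackHandlerWithTokenCounts from the tests. A rough sketch, assuming configured AWS credentials and access to the anthropic.claude-v2 model (the handler class here is illustrative, not part of the library):

    from langchain_core.callbacks import BaseCallbackHandler
    from langchain_core.messages import HumanMessage
    from langchain_core.outputs import LLMResult

    from langchain_aws import ChatBedrock


    class TokenUsageHandler(BaseCallbackHandler):
        """Accumulates the usage and stop_reason surfaced in llm_output."""

        def __init__(self) -> None:
            self.prompt_tokens = 0
            self.completion_tokens = 0
            self.stop_reason = None

        def on_llm_end(self, response: LLMResult, **kwargs) -> None:
            usage = (response.llm_output or {}).get("usage", {})
            self.prompt_tokens += usage.get("prompt_tokens", 0)
            self.completion_tokens += usage.get("completion_tokens", 0)
            self.stop_reason = (response.llm_output or {}).get("stop_reason")


    handler = TokenUsageHandler()
    chat = ChatBedrock(
        model_id="anthropic.claude-v2", streaming=True, callbacks=[handler]
    )
    chat.invoke([HumanMessage(content="Hello")])
    # handler.prompt_tokens, handler.completion_tokens and handler.stop_reason
    # are now populated from the combined streaming metadata.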
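The retriever change adds a min_score_confidence knob to AmazonKnowledgeBasesRetriever that drops results whose relevance score is missing or below the threshold. A usage sketch, assuming configured AWS credentials and an existing knowledge base (the ID below is a placeholder):

    from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever

    retriever = AmazonKnowledgeBasesRetriever(
        knowledge_base_id="EXAMPLEKBID",  # placeholder knowledge base id
        retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}},
        min_score_confidence=0.8,  # keep only results with metadata["score"] >= 0.8
    )
    docs = retriever.invoke("What is Amazon Bedrock?")
    # Results without a score, or scoring below 0.8, are filtered out.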