From 6baea9e9640af5d4bb3e276729184e0de45a5964 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 16 Jul 2024 14:29:17 -0400 Subject: [PATCH 1/3] add broken test --- .../chat_models/test_bedrock_converse.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index c6c23ea4..34246422 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -1,10 +1,11 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Literal, Type import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_aws import ChatBedrockConverse @@ -31,6 +32,18 @@ def standard_chat_model_params(self) -> dict: def supports_image_inputs(self) -> bool: return True + def test_structured_output_snake_case(self, model: BaseChatModel) -> None: + class ClassifyQuery(BaseModel): + """Classify a query.""" + + query_type: Literal["cat", "dog"] = Field( + description="Classify a query as related to cats or dogs." + ) + + chat = model.with_structured_output(ClassifyQuery) + for chunk in chat.stream("How big are cats?"): + assert isinstance(chunk, ClassifyQuery) + @pytest.mark.skip(reason="Needs guardrails setup to run.") def test_guardrails() -> None: From ad406e8ab48ba41b63da90b56c272d34b2abf995 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 16 Jul 2024 14:31:41 -0400 Subject: [PATCH 2/3] fix test --- .../aws/langchain_aws/chat_models/bedrock_converse.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index bb303e03..f99e7ac6 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -815,12 +815,17 @@ def _camel_to_snake_keys(obj: _T) -> _T: def _snake_to_camel_keys(obj: _T) -> _T: + excluded_keys = {"inputSchema"} # inputSchema contains user-provided schema if isinstance(obj, list): return cast(_T, [_snake_to_camel_keys(e) for e in obj]) elif isinstance(obj, dict): - return cast( - _T, {_snake_to_camel(k): _snake_to_camel_keys(v) for k, v in obj.items()} - ) + _dict = {} + for k, v in obj.items(): + if k in excluded_keys: + _dict[k] = v + else: + _dict[_snake_to_camel(k)] = _snake_to_camel_keys(v) + return cast(_T, _dict) else: return obj From 2c206cb84d4bdada9c5055aafd3088356ec42c4c Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 16 Jul 2024 14:59:16 -0400 Subject: [PATCH 3/3] pass in excluded_keys --- .../chat_models/bedrock_converse.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index f99e7ac6..f22eb420 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -403,7 +403,9 @@ def _generate( ) -> ChatResult: """Top Level call""" bedrock_messages, system = _messages_to_bedrock(messages) - params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs)) + params = self._converse_params( + stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"}) + ) response = self.client.converse( messages=bedrock_messages, system=system, **params ) @@ -418,7 +420,9 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: bedrock_messages, system = _messages_to_bedrock(messages) - params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs)) + params = self._converse_params( + stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"}) + ) response = self.client.converse_stream( messages=bedrock_messages, system=system, **params ) @@ -814,17 +818,20 @@ def _camel_to_snake_keys(obj: _T) -> _T: return obj -def _snake_to_camel_keys(obj: _T) -> _T: - excluded_keys = {"inputSchema"} # inputSchema contains user-provided schema +def _snake_to_camel_keys(obj: _T, excluded_keys: set = set()) -> _T: if isinstance(obj, list): - return cast(_T, [_snake_to_camel_keys(e) for e in obj]) + return cast( + _T, [_snake_to_camel_keys(e, excluded_keys=excluded_keys) for e in obj] + ) elif isinstance(obj, dict): _dict = {} for k, v in obj.items(): if k in excluded_keys: _dict[k] = v else: - _dict[_snake_to_camel(k)] = _snake_to_camel_keys(v) + _dict[_snake_to_camel(k)] = _snake_to_camel_keys( + v, excluded_keys=excluded_keys + ) return cast(_T, _dict) else: return obj