From 35dc5402b757751abedf40f5ba83dd981c17b2c5 Mon Sep 17 00:00:00 2001 From: ccurme Date: Tue, 16 Jul 2024 16:06:40 -0400 Subject: [PATCH] aws[patch]: retain snake case for input schemas (#114) --- .../chat_models/bedrock_converse.py | 24 ++++++++++++++----- .../chat_models/test_bedrock_converse.py | 15 +++++++++++- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index bb303e03..f22eb420 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -403,7 +403,9 @@ def _generate( ) -> ChatResult: """Top Level call""" bedrock_messages, system = _messages_to_bedrock(messages) - params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs)) + params = self._converse_params( + stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"}) + ) response = self.client.converse( messages=bedrock_messages, system=system, **params ) @@ -418,7 +420,9 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: bedrock_messages, system = _messages_to_bedrock(messages) - params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs)) + params = self._converse_params( + stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"}) + ) response = self.client.converse_stream( messages=bedrock_messages, system=system, **params ) @@ -814,13 +818,21 @@ def _camel_to_snake_keys(obj: _T) -> _T: return obj -def _snake_to_camel_keys(obj: _T) -> _T: +def _snake_to_camel_keys(obj: _T, excluded_keys: set = set()) -> _T: if isinstance(obj, list): - return cast(_T, [_snake_to_camel_keys(e) for e in obj]) - elif isinstance(obj, dict): return cast( - _T, {_snake_to_camel(k): _snake_to_camel_keys(v) for k, v in obj.items()} + _T, [_snake_to_camel_keys(e, excluded_keys=excluded_keys) for e in obj] ) + elif isinstance(obj, dict): + _dict = {} + for k, v in obj.items(): + if k in excluded_keys: + _dict[k] = v + else: + _dict[_snake_to_camel(k)] = _snake_to_camel_keys( + v, excluded_keys=excluded_keys + ) + return cast(_T, _dict) else: return obj diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index c6c23ea4..34246422 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -1,10 +1,11 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Literal, Type import pytest from langchain_core.language_models import BaseChatModel from langchain_core.messages import HumanMessage +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_aws import ChatBedrockConverse @@ -31,6 +32,18 @@ def standard_chat_model_params(self) -> dict: def supports_image_inputs(self) -> bool: return True + def test_structured_output_snake_case(self, model: BaseChatModel) -> None: + class ClassifyQuery(BaseModel): + """Classify a query.""" + + query_type: Literal["cat", "dog"] = Field( + description="Classify a query as related to cats or dogs." + ) + + chat = model.with_structured_output(ClassifyQuery) + for chunk in chat.stream("How big are cats?"): + assert isinstance(chunk, ClassifyQuery) + @pytest.mark.skip(reason="Needs guardrails setup to run.") def test_guardrails() -> None: