Skip to content

Commit

Permalink
aws[patch]: retain snake case for input schemas (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme authored Jul 16, 2024
1 parent 6f1fec9 commit 35dc540
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
24 changes: 18 additions & 6 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,9 @@ def _generate(
) -> ChatResult:
"""Top Level call"""
bedrock_messages, system = _messages_to_bedrock(messages)
params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs))
params = self._converse_params(
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"})
)
response = self.client.converse(
messages=bedrock_messages, system=system, **params
)
Expand All @@ -418,7 +420,9 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
bedrock_messages, system = _messages_to_bedrock(messages)
params = self._converse_params(stop=stop, **_snake_to_camel_keys(kwargs))
params = self._converse_params(
stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"})
)
response = self.client.converse_stream(
messages=bedrock_messages, system=system, **params
)
Expand Down Expand Up @@ -814,13 +818,21 @@ def _camel_to_snake_keys(obj: _T) -> _T:
return obj


def _snake_to_camel_keys(obj: _T) -> _T:
def _snake_to_camel_keys(obj: _T, excluded_keys: set = set()) -> _T:
if isinstance(obj, list):
return cast(_T, [_snake_to_camel_keys(e) for e in obj])
elif isinstance(obj, dict):
return cast(
_T, {_snake_to_camel(k): _snake_to_camel_keys(v) for k, v in obj.items()}
_T, [_snake_to_camel_keys(e, excluded_keys=excluded_keys) for e in obj]
)
elif isinstance(obj, dict):
_dict = {}
for k, v in obj.items():
if k in excluded_keys:
_dict[k] = v
else:
_dict[_snake_to_camel(k)] = _snake_to_camel_keys(
v, excluded_keys=excluded_keys
)
return cast(_T, _dict)
else:
return obj

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Standard LangChain interface tests"""

from typing import Type
from typing import Literal, Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_aws import ChatBedrockConverse
Expand All @@ -31,6 +32,18 @@ def standard_chat_model_params(self) -> dict:
def supports_image_inputs(self) -> bool:
return True

def test_structured_output_snake_case(self, model: BaseChatModel) -> None:
class ClassifyQuery(BaseModel):
"""Classify a query."""

query_type: Literal["cat", "dog"] = Field(
description="Classify a query as related to cats or dogs."
)

chat = model.with_structured_output(ClassifyQuery)
for chunk in chat.stream("How big are cats?"):
assert isinstance(chunk, ClassifyQuery)


@pytest.mark.skip(reason="Needs guardrails setup to run.")
def test_guardrails() -> None:
Expand Down

0 comments on commit 35dc540

Please sign in to comment.