Skip to content

Commit

Permalink
Merge pull request #72 from langchain-ai/cc/standard_tests
Browse files Browse the repository at this point in the history
add standard tests
  • Loading branch information
efriis authored Jun 11, 2024
2 parents 46f2a7f + 80d683a commit 1bdbe20
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 2 deletions.
24 changes: 22 additions & 2 deletions libs/aws/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/aws/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pytest-cov = "^4.1.0"
syrupy = "^4.0.2"
pytest-asyncio = "^0.23.2"
langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" }
langchain-standard-tests = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests"}

[tool.poetry.group.codespell]
optional = true
Expand Down
76 changes: 76 additions & 0 deletions libs/aws/tests/integration_tests/chat_models/test_standard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Standard LangChain interface tests"""

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_aws.chat_models.bedrock import ChatBedrock


class TestBedrockStandard(ChatModelIntegrationTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatBedrock

@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
}

@pytest.mark.xfail(reason="Not implemented.")
def test_usage_metadata(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_usage_metadata(
chat_model_class,
chat_model_params,
)

@pytest.mark.xfail(reason="Not implemented.")
def test_stop_sequence(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_stop_sequence(
chat_model_class,
chat_model_params,
)

@pytest.mark.xfail(reason="Not yet implemented.")
def test_tool_message_histories_string_content(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
super().test_tool_message_histories_string_content(
chat_model_class, chat_model_params, chat_model_has_tool_calling
)

@pytest.mark.xfail(reason="Not yet implemented.")
def test_tool_message_histories_list_content(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
super().test_tool_message_histories_list_content(
chat_model_class, chat_model_params, chat_model_has_tool_calling
)

@pytest.mark.xfail(reason="Not yet implemented.")
def test_structured_few_shot_examples(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
chat_model_has_tool_calling: bool,
) -> None:
super().test_structured_few_shot_examples(
chat_model_class, chat_model_params, chat_model_has_tool_calling
)
44 changes: 44 additions & 0 deletions libs/aws/tests/unit_tests/test_standard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Standard LangChain interface tests"""

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests

from langchain_aws.chat_models.bedrock import ChatBedrock


class TestBedrockStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatBedrock

@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
"region_name": "us-east-1",
}

@pytest.mark.xfail(reason="Not implemented.")
def test_chat_model_init_api_key(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_chat_model_init_api_key(
chat_model_class,
chat_model_params,
)

@pytest.mark.xfail(reason="Not implemented.")
def test_standard_params(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_standard_params(
chat_model_class,
chat_model_params,
)

0 comments on commit 1bdbe20

Please sign in to comment.