Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Aug 27, 2024
1 parent 4ac46b8 commit 77cd167
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
12 changes: 6 additions & 6 deletions libs/cerebras/langchain_cerebras/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
secret_from_env,
)

# We ignore the "unused imports" here since we want to reexport these from this package.
Expand Down Expand Up @@ -309,25 +309,25 @@ def _get_ls_params(

model_name: str = Field(alias="model")
"""Model name to use."""
cerebras_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
cerebras_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("CEREBRAS_API_KEY", default=None),
)
"""Automatically inferred from env are `CEREBRAS_API_KEY` if not provided."""
cerebras_api_base: Optional[str] = Field(
default=CEREBRAS_BASE_URL, alias="base_url"
)

cerebras_proxy: Optional[str] = None

@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")

values["cerebras_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "cerebras_api_key", "CEREBRAS_API_KEY", "")
)
values["cerebras_api_base"] = os.getenv(
"CEREBRAS_API_BASE", values["cerebras_api_base"]
)
Expand Down
3 changes: 3 additions & 0 deletions libs/cerebras/tests/unit_tests/test_base_standard.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Standard LangChain interface tests"""

import os
from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ChatModelUnitTests

from langchain_cerebras import ChatCerebras

os.environ["CEREBRAS_API_KEY"] = "foo"


class TestCerebrasStandard(ChatModelUnitTests):
@property
Expand Down

0 comments on commit 77cd167

Please sign in to comment.