Commit
cerebras: pydantic compat, release 0.2.0 (#3)
efriis authored Sep 16, 2024
1 parent d28ad00 commit 0570d36
Showing 3 changed files with 281 additions and 266 deletions.
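
In short, the diff below swaps the pydantic v1 style root_validator imported from langchain_core.pydantic_v1, which received and returned a values dict, for a native pydantic v2 model_validator(mode="after") that operates on self, and it moves environment-variable handling into field default factories built with from_env and secret_from_env. A minimal standalone sketch of that pattern follows; ExampleSettings is a hypothetical class used only for illustration, not the package's actual ChatCerebras definition.

    # Sketch of the pydantic v2 pattern this commit adopts (hypothetical class).
    from typing import Optional

    from langchain_core.utils import from_env, secret_from_env
    from pydantic import BaseModel, Field, SecretStr, model_validator
    from typing_extensions import Self

    CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1/"


    class ExampleSettings(BaseModel):
        # Env-backed defaults replace the os.getenv / get_from_dict_or_env
        # calls that previously lived inside the validator.
        api_key: Optional[SecretStr] = Field(
            default_factory=secret_from_env("CEREBRAS_API_KEY", default=None)
        )
        api_base: str = Field(
            default_factory=from_env("CEREBRAS_API_BASE", default=CEREBRAS_BASE_URL)
        )
        n: int = 1
        streaming: bool = False

        # pydantic v2 replacement for @root_validator(pre=False, skip_on_failure=True):
        # runs after field validation, checks the instance, and returns it.
        @model_validator(mode="after")
        def validate_environment(self) -> Self:
            if self.n < 1:
                raise ValueError("n must be at least 1.")
            if self.n > 1 and self.streaming:
                raise ValueError("n must be 1 when streaming.")
            return self
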
libs/cerebras/langchain_cerebras/chat_models.py (35 additions, 52 deletions)
@@ -1,6 +1,5 @@
 """Wrapper around Cerebras' Chat Completions API."""
 
-import os
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
 
 import openai
@@ -11,16 +10,17 @@
 from langchain_core.language_models.chat_models import LangSmithParams
 from langchain_core.messages import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk
-from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
 from langchain_core.utils import (
-    get_from_dict_or_env,
+    from_env,
     secret_from_env,
 )
 
 # We ignore the "unused imports" here since we want to reexport these from this package.
 from langchain_openai.chat_models.base import (
     BaseChatOpenAI,
 )
+from pydantic import Field, SecretStr, model_validator
+from typing_extensions import Self
 
 CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1/"
 
@@ -314,88 +314,71 @@ def _get_ls_params(
         default_factory=secret_from_env("CEREBRAS_API_KEY", default=None),
     )
     """Automatically inferred from env are `CEREBRAS_API_KEY` if not provided."""
-    cerebras_api_base: Optional[str] = Field(
-        default=CEREBRAS_BASE_URL, alias="base_url"
+    cerebras_api_base: str = Field(
+        default_factory=from_env("CEREBRAS_API_BASE", default=CEREBRAS_BASE_URL),
+        alias="base_url",
     )
 
-    cerebras_proxy: Optional[str] = None
+    cerebras_proxy: str = Field(default_factory=from_env("CEREBRAS_PROXY", default=""))
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
         """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
+        if self.n < 1:
             raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
+        if self.n > 1 and self.streaming:
             raise ValueError("n must be 1 when streaming.")
 
-        values["cerebras_api_base"] = os.getenv(
-            "CEREBRAS_API_BASE", values["cerebras_api_base"]
-        )
-
-        values["cerebras_proxy"] = get_from_dict_or_env(
-            values, "cerebras_proxy", "CEREBRAS_PROXY", default=""
-        )
-
         client_params = {
             "api_key": (
-                values["cerebras_api_key"].get_secret_value()
-                if values["cerebras_api_key"]
+                self.cerebras_api_key.get_secret_value()
+                if self.cerebras_api_key
                 else None
             ),
-            # Ensure we always fallback to the Cerebras API url.
-            "base_url": (
-                values["cerebras_api_base"]
-                if values["cerebras_api_base"]
-                else CEREBRAS_BASE_URL
-            ),
-            "timeout": values["request_timeout"],
-            "max_retries": values["max_retries"],
-            "default_headers": values["default_headers"],
-            "default_query": values["default_query"],
+            "base_url": self.cerebras_api_base,
+            "timeout": self.request_timeout,
+            "max_retries": self.max_retries,
+            "default_headers": self.default_headers,
+            "default_query": self.default_query,
         }
 
-        if values["cerebras_proxy"] and (
-            values["http_client"] or values["http_async_client"]
-        ):
-            cerebras_proxy = values["cerebras_proxy"]
-            http_client = values["http_client"]
-            http_async_client = values["http_async_client"]
+        if self.cerebras_proxy and (self.http_client or self.http_async_client):
             raise ValueError(
                 "Cannot specify 'cerebras_proxy' if one of "
                 "'http_client'/'http_async_client' is already specified. Received:\n"
-                f"{cerebras_proxy=}\n{http_client=}\n{http_async_client=}"
+                f"{self.cerebras_proxy=}\n{self.http_client=}\n{self.http_async_client=}"
             )
-        if not values.get("client"):
-            if values["cerebras_proxy"] and not values["http_client"]:
+        if not self.client:
+            if self.cerebras_proxy and not self.http_client:
                 try:
                     import httpx
                 except ImportError as e:
                     raise ImportError(
                         "Could not import httpx python package. "
                         "Please install it with `pip install httpx`."
                     ) from e
-                values["http_client"] = httpx.Client(proxy=values["cerebras_proxy"])
-            sync_specific = {"http_client": values["http_client"]}
-            values["root_client"] = openai.OpenAI(**client_params, **sync_specific)
-            values["client"] = values["root_client"].chat.completions
-        if not values.get("async_client"):
-            if values["cerebras_proxy"] and not values["http_async_client"]:
+                self.http_client = httpx.Client(proxy=self.cerebras_proxy)
+            sync_specific = {"http_client": self.http_client}
+            self.root_client = openai.OpenAI(**client_params, **sync_specific)  # type: ignore
+            self.client = self.root_client.chat.completions
+        if not self.async_client:
+            if self.cerebras_proxy and not self.http_async_client:
                 try:
                     import httpx
                 except ImportError as e:
                     raise ImportError(
                         "Could not import httpx python package. "
                         "Please install it with `pip install httpx`."
                     ) from e
-                values["http_async_client"] = httpx.AsyncClient(
-                    proxy=values["cerebras_proxy"]
-                )
-            async_specific = {"http_client": values["http_async_client"]}
-            values["root_async_client"] = openai.AsyncOpenAI(
-                **client_params, **async_specific
+                self.http_async_client = httpx.AsyncClient(proxy=self.cerebras_proxy)
+            async_specific = {"http_client": self.http_async_client}
+            self.root_async_client = openai.AsyncOpenAI(
+                **client_params,  # type: ignore
+                **async_specific,  # type: ignore
             )
-            values["async_client"] = values["root_async_client"].chat.completions
-        return values
+            self.async_client = self.root_async_client.chat.completions
+        return self
 
     # Patch tool calling w/ streaming.
     def _stream(
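
With this release, the API key, base URL, and proxy are resolved from CEREBRAS_API_KEY, CEREBRAS_API_BASE, and CEREBRAS_PROXY when the model object is constructed, rather than inside the validator. A hypothetical usage sketch; the model name and the exact constructor surface are assumptions, not taken from this diff.

    import os

    from langchain_cerebras import ChatCerebras

    # Placeholder key; in practice, export a real CEREBRAS_API_KEY.
    os.environ.setdefault("CEREBRAS_API_KEY", "csk-...")

    llm = ChatCerebras(
        model="llama3.1-8b",  # assumed model name
        base_url="https://api.cerebras.ai/v1/",  # alias for cerebras_api_base
    )
    print(llm.invoke("Say hello in one word.").content)

As enforced by the validator above, supplying cerebras_proxy together with an explicit http_client or http_async_client still raises a ValueError.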
