Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Aug 26, 2024
1 parent eff9353 commit 4ac46b8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
1 change: 1 addition & 0 deletions libs/cerebras/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__pycache__
.mypy_cache
35 changes: 15 additions & 20 deletions libs/cerebras/langchain_cerebras/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
"""Wrapper around Cerebras' Chat Completions API."""

import os
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
)
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast

import openai
from langchain_core.callbacks import (
Expand All @@ -25,17 +18,14 @@
)

# We ignore the "unused imports" here since we want to reexport these from this package.
from langchain_openai.chat_models.base import ( # pyright: ignore # noqa F401
ChatOpenAI,
_convert_dict_to_message,
_convert_message_to_dict,
_format_message_content,
from langchain_openai.chat_models.base import (
BaseChatOpenAI,
)

CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1/"


class ChatCerebras(ChatOpenAI):
class ChatCerebras(BaseChatOpenAI):
r"""ChatCerebras chat model.
Setup:
Expand Down Expand Up @@ -416,11 +406,12 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
if kwargs.get("tools"):
return [
super()._generate(messages, stop, run_manager, **kwargs).generations[0]
]
yield cast(
ChatGenerationChunk,
super()._generate(messages, stop, run_manager, **kwargs).generations[0],
)
else:
return super()._stream(messages, stop, run_manager, **kwargs)
yield from super()._stream(messages, stop, run_manager, **kwargs)

async def _astream(
self,
Expand All @@ -430,9 +421,13 @@ async def _astream(
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
if kwargs.get("tools"):
generation = await super()._agenerate(messages, stop, run_manager, **kwargs)
yield (
await super()._agenerate(messages, stop, run_manager, **kwargs)
).generations[0]
cast(
ChatGenerationChunk,
generation.generations[0],
)
)
else:
async for msg in super()._astream(messages, stop, run_manager, **kwargs):
yield msg

0 comments on commit 4ac46b8

Please sign in to comment.