From 4ac46b80deab6d02d1101df4f2e4a2ebb66887ab Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Mon, 26 Aug 2024 16:41:04 -0700 Subject: [PATCH] x --- libs/cerebras/.gitignore | 1 + .../langchain_cerebras/chat_models.py | 35 ++++++++----------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/libs/cerebras/.gitignore b/libs/cerebras/.gitignore index bee8a64..3883e04 100644 --- a/libs/cerebras/.gitignore +++ b/libs/cerebras/.gitignore @@ -1 +1,2 @@ __pycache__ +.mypy_cache diff --git a/libs/cerebras/langchain_cerebras/chat_models.py b/libs/cerebras/langchain_cerebras/chat_models.py index 257bcda..441b644 100644 --- a/libs/cerebras/langchain_cerebras/chat_models.py +++ b/libs/cerebras/langchain_cerebras/chat_models.py @@ -1,14 +1,7 @@ """Wrapper around Cerebras' Chat Completions API.""" import os -from typing import ( - Any, - AsyncIterator, - Dict, - Iterator, - List, - Optional, -) +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast import openai from langchain_core.callbacks import ( @@ -25,17 +18,14 @@ ) # We ignore the "unused imports" here since we want to reexport these from this package. -from langchain_openai.chat_models.base import ( # pyright: ignore # noqa F401 - ChatOpenAI, - _convert_dict_to_message, - _convert_message_to_dict, - _format_message_content, +from langchain_openai.chat_models.base import ( + BaseChatOpenAI, ) CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1/" -class ChatCerebras(ChatOpenAI): +class ChatCerebras(BaseChatOpenAI): r"""ChatCerebras chat model. Setup: @@ -416,11 +406,12 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: if kwargs.get("tools"): - return [ - super()._generate(messages, stop, run_manager, **kwargs).generations[0] - ] + yield cast( + ChatGenerationChunk, + super()._generate(messages, stop, run_manager, **kwargs).generations[0], + ) else: - return super()._stream(messages, stop, run_manager, **kwargs) + yield from super()._stream(messages, stop, run_manager, **kwargs) async def _astream( self, @@ -430,9 +421,13 @@ async def _astream( **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: if kwargs.get("tools"): + generation = await super()._agenerate(messages, stop, run_manager, **kwargs) yield ( - await super()._agenerate(messages, stop, run_manager, **kwargs) - ).generations[0] + cast( + ChatGenerationChunk, + generation.generations[0], + ) + ) else: async for msg in super()._astream(messages, stop, run_manager, **kwargs): yield msg