diff --git a/libs/community/langchain_community/chat_models/yandex.py b/libs/community/langchain_community/chat_models/yandex.py
index 02aed41650c45..74510b1e20d1a 100644
--- a/libs/community/langchain_community/chat_models/yandex.py
+++ b/libs/community/langchain_community/chat_models/yandex.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional, cast
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, cast
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -12,11 +12,12 @@
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     HumanMessage,
     SystemMessage,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from tenacity import (
     before_sleep_log,
     retry,
@@ -120,11 +121,41 @@ async def _agenerate(
         message = AIMessage(content=text)
         return ChatResult(generations=[ChatGeneration(message=message)])
 
-
-def _make_request(
-    self: ChatYandexGPT,
-    messages: List[BaseMessage],
-) -> str:
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        stream_resp = completion_with_retry(self, messages=messages, stream=True)
+        for data in stream_resp:
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        async for data in await acompletion_with_retry(
+            self, messages=messages, stream=True
+        ):
+            delta = data
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            if run_manager:
+                await run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
+
+
+def _generate_completion(
+    self: ChatYandexGPT, messages: List[BaseMessage], stream: bool = False
+) -> Any:
     try:
         import grpc
         from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
@@ -166,15 +197,18 @@ def _make_request(
         completion_options=CompletionOptions(
             temperature=DoubleValue(value=self.temperature),
             max_tokens=Int64Value(value=self.max_tokens),
+            stream=stream,
         ),
         messages=[Message(**message) for message in message_history],
     )
     stub = TextGenerationServiceStub(channel)
     res = stub.Completion(request, metadata=self.grpc_metadata)
-    return list(res)[0].alternatives[0].message.text
+    return res
 
 
-async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str:
+async def _agenerate_completion(
+    self: ChatYandexGPT, messages: List[BaseMessage], stream: bool = False
+) -> Any:
     try:
         import asyncio
 
@@ -219,17 +253,20 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
     message_history = _parse_chat_history(messages)
     operation_api_url = "operation.api.cloud.yandex.net:443"
     channel_credentials = grpc.ssl_channel_credentials()
+
     async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
         request = CompletionRequest(
             model_uri=self.model_uri,
             completion_options=CompletionOptions(
                 temperature=DoubleValue(value=self.temperature),
                max_tokens=Int64Value(value=self.max_tokens),
+                stream=stream,
             ),
             messages=[Message(**message) for message in message_history],
         )
         stub = TextGenerationAsyncServiceStub(channel)
         operation = await stub.Completion(request, metadata=self.grpc_metadata)
+
         async with grpc.aio.secure_channel(
             operation_api_url, channel_credentials
         ) as operation_channel:
@@ -242,9 +279,42 @@
                     metadata=self.grpc_metadata,
                 )
 
-    completion_response = CompletionResponse()
-    operation.response.Unpack(completion_response)
-    return completion_response.alternatives[0].message.text
+            completion_response = CompletionResponse()
+            operation.response.Unpack(completion_response)
+
+            return completion_response
+
+
+def _make_request_invoke(
+    self: ChatYandexGPT,
+    messages: List[BaseMessage],
+) -> Any:
+    return (
+        list(_generate_completion(self, messages, False))[0]
+        .alternatives[0]
+        .message.text
+    )
+
+
+def _make_request_stream(
+    self: ChatYandexGPT,
+    messages: List[BaseMessage],
+) -> Any:
+    result = _generate_completion(self, messages, True)
+    for chunk in result:
+        yield chunk.alternatives[0].message.text
+
+
+async def _amake_request_invoke(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+    result = await _agenerate_completion(llm, stream=False, **kwargs)
+    return result.alternatives[0].message.text
+
+
+async def _amake_request_stream(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+    result = await _agenerate_completion(llm, stream=True, **kwargs)
+    # The async operations API returns one final response; yield each alternative.
+    for alternative in result.alternatives:
+        yield alternative.message.text
 
 
 def _create_retry_decorator(llm: ChatYandexGPT) -> Callable[[Any], Any]:
@@ -261,23 +331,31 @@ def _create_retry_decorator(llm: ChatYandexGPT) -> Callable[[Any], Any]:
     )
 
 
-def completion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: ChatYandexGPT, stream: bool = False, **kwargs: Any
+) -> Any:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm)
 
     @retry_decorator
     def _completion_with_retry(**_kwargs: Any) -> Any:
-        return _make_request(llm, **_kwargs)
+        if stream:
+            return _make_request_stream(llm, **_kwargs)
+        return _make_request_invoke(llm, **_kwargs)
 
     return _completion_with_retry(**kwargs)
 
 
-async def acompletion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+async def acompletion_with_retry(
+    llm: ChatYandexGPT, stream: bool = False, **kwargs: Any
+) -> Any:
     """Use tenacity to retry the async completion call."""
     retry_decorator = _create_retry_decorator(llm)
 
     @retry_decorator
     async def _completion_with_retry(**_kwargs: Any) -> Any:
-        return await _amake_request(llm, **_kwargs)
+        if stream:
+            return _amake_request_stream(llm, **_kwargs)
+        return await _amake_request_invoke(llm, **_kwargs)
 
     return await _completion_with_retry(**kwargs)
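
For reviewers, a minimal usage sketch of the new synchronous streaming path. This is illustrative only, not part of the patch: the credentials below are placeholders, and it assumes the `yandexcloud` and `grpcio` packages are installed. `ChatYandexGPT.stream()` is inherited from `BaseChatModel` and drives the new `_stream()` override.

```python
# Sketch, not part of the patch: exercise the new _stream() override
# through the public BaseChatModel.stream() API.
# "<api-key>" and "<folder-id>" are placeholder credentials.
from langchain_community.chat_models import ChatYandexGPT
from langchain_core.messages import HumanMessage

chat = ChatYandexGPT(
    api_key="<api-key>",      # placeholder Yandex Cloud API key
    folder_id="<folder-id>",  # placeholder folder ID
    temperature=0.2,
)

# Each iteration yields a message chunk built in _stream() from one
# response of the gRPC streaming call.
for chunk in chat.stream([HumanMessage(content="Tell me a short joke.")]):
    print(chunk.content, end="", flush=True)
```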
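
And the async counterpart. As the diff shows, `_astream()` is backed by the long-running operations API: `_agenerate_completion` polls the operation until it completes and unpacks a single `CompletionResponse`, so `astream()` emits one chunk per alternative after completion rather than token-by-token deltas.

```python
# Sketch, not part of the patch: the async path via BaseChatModel.astream().
# Same placeholder credentials as above.
import asyncio

from langchain_community.chat_models import ChatYandexGPT
from langchain_core.messages import HumanMessage


async def main() -> None:
    chat = ChatYandexGPT(api_key="<api-key>", folder_id="<folder-id>")
    # Chunks arrive only once the server-side operation is done (one per
    # alternative), since the async gRPC service is operation-based.
    async for chunk in chat.astream([HumanMessage(content="Hi!")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```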