
community: added stream and astream to chatyandexgpt #25483

Open · wants to merge 10 commits into base: master
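For orientation, here is a minimal usage sketch of the streaming surface this PR wires up. It is not part of the change itself; the credentials and prompts are placeholders (`ChatYandexGPT` also accepts an IAM token instead of an API key).

```python
import asyncio

from langchain_community.chat_models import ChatYandexGPT

# Placeholder credentials, not part of the PR.
llm = ChatYandexGPT(api_key="<your-api-key>")

# Synchronous streaming: .stream() is now backed by the _stream() added below.
for chunk in llm.stream("Write a haiku about gRPC"):
    print(chunk.content, end="", flush=True)

# Async streaming: .astream() is backed by the new _astream().
async def main() -> None:
    async for chunk in llm.astream("And one more, please"):
        print(chunk.content, end="", flush=True)

asyncio.run(main())
```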
109 changes: 92 additions & 17 deletions libs/community/langchain_community/chat_models/yandex.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import logging
-from typing import Any, Callable, Dict, List, Optional, cast
+from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, cast

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
@@ -12,11 +12,12 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from tenacity import (
    before_sleep_log,
    retry,
@@ -120,11 +121,45 @@ async def _agenerate(
        message = AIMessage(content=text)
        return ChatResult(generations=[ChatGeneration(message=message)])


-def _make_request(
-    self: ChatYandexGPT,
-    messages: List[BaseMessage],
-) -> str:
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        stream_resp = completion_with_retry(self, messages=messages, stream=True)
        current_text = ""
        for data in stream_resp:
            delta = data[len(current_text) :]
            current_text = data
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        current_text = ""
        async for data in await acompletion_with_retry(
            self, messages=messages, stream=True
        ):
            delta = data[len(current_text) :]
            current_text = data
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            if run_manager:
                await run_manager.on_llm_new_token(delta, chunk=chunk)
            yield chunk


def _generate_completion(
    self: ChatYandexGPT, messages: List[BaseMessage], stream: Optional[bool] = None
):
    try:
        import grpc
        from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
@@ -166,15 +201,15 @@ def _make_request(
        completion_options=CompletionOptions(
            temperature=DoubleValue(value=self.temperature),
            max_tokens=Int64Value(value=self.max_tokens),
            stream=stream,
        ),
        messages=[Message(**message) for message in message_history],
    )
    stub = TextGenerationServiceStub(channel)
-    res = stub.Completion(request, metadata=self._grpc_metadata)
-    return list(res)[0].alternatives[0].message.text
+    return stub.Completion(request, metadata=self._grpc_metadata)


-async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str:
+async def _agenerate_completion(
+    self: ChatYandexGPT, messages: List[BaseMessage], stream: Optional[bool] = False
+):
    try:
        import asyncio

@@ -219,17 +254,20 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str:
    message_history = _parse_chat_history(messages)
    operation_api_url = "operation.api.cloud.yandex.net:443"
    channel_credentials = grpc.ssl_channel_credentials()

    async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
        request = CompletionRequest(
            model_uri=self.model_uri,
            completion_options=CompletionOptions(
                temperature=DoubleValue(value=self.temperature),
                max_tokens=Int64Value(value=self.max_tokens),
                stream=stream,  # Use the stream parameter
            ),
            messages=[Message(**message) for message in message_history],
        )
        stub = TextGenerationAsyncServiceStub(channel)
        operation = await stub.Completion(request, metadata=self._grpc_metadata)

        async with grpc.aio.secure_channel(
            operation_api_url, channel_credentials
        ) as operation_channel:
@@ -242,9 +280,38 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str:
                    metadata=self._grpc_metadata,
                )

-            completion_response = CompletionResponse()
-            operation.response.Unpack(completion_response)
-            return completion_response.alternatives[0].message.text
+            completion_response = CompletionResponse()
+            operation.response.Unpack(completion_response)
+
+            return completion_response


def _make_request_invoke(
    self: ChatYandexGPT,
    messages: List[BaseMessage],
) -> str:
    result = _generate_completion(self, messages, None)
    return list(result)[0].alternatives[0].message.text


def _make_request_stream(
    self: ChatYandexGPT,
    messages: List[BaseMessage],
) -> Iterator[str]:
    result = _generate_completion(self, messages, True)
    for chunk in result:
        yield chunk.alternatives[0].message.text


async def _amake_request_invoke(llm: ChatYandexGPT, **kwargs: Any) -> str:
    result = await _agenerate_completion(llm, stream=None, **kwargs)
    return result.alternatives[0].message.text


async def _amake_request_stream(
    llm: ChatYandexGPT, **kwargs: Any
) -> AsyncIterator[str]:
    result = await _agenerate_completion(llm, stream=True, **kwargs)
    for alternative in result.alternatives:
        yield alternative.message.text


def _create_retry_decorator(llm: ChatYandexGPT) -> Callable[[Any], Any]:
@@ -261,23 +328,31 @@ def _create_retry_decorator(llm: ChatYandexGPT) -> Callable[[Any], Any]:
    )


-def completion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+def completion_with_retry(
+    llm: ChatYandexGPT, stream: bool = False, **kwargs: Any
+) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    def _completion_with_retry(**_kwargs: Any) -> Any:
-        return _make_request(llm, **_kwargs)
+        if stream:
+            return _make_request_stream(llm, **_kwargs)
+        return _make_request_invoke(llm, **_kwargs)

    return _completion_with_retry(**kwargs)


-async def acompletion_with_retry(llm: ChatYandexGPT, **kwargs: Any) -> Any:
+async def acompletion_with_retry(
+    llm: ChatYandexGPT, stream: bool = False, **kwargs: Any
+) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _completion_with_retry(**_kwargs: Any) -> Any:
-        return await _amake_request(llm, **_kwargs)
+        if stream:
+            return _amake_request_stream(llm, **_kwargs)
+        return await _amake_request_invoke(llm, **_kwargs)

    return await _completion_with_retry(**kwargs)
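The `_stream`/`_astream` implementations above assume YandexGPT's streamed responses are cumulative: each event carries the full text generated so far, so the code tracks `current_text` and yields only the new suffix. Below is a standalone sketch of that slicing idea; `iter_deltas` is a hypothetical helper for illustration, not part of the PR.

```python
from typing import Iterable, Iterator

def iter_deltas(cumulative_events: Iterable[str]) -> Iterator[str]:
    """Turn cumulative snapshots ("He", "Hello", "Hello!") into per-event deltas."""
    current_text = ""
    for data in cumulative_events:
        delta = data[len(current_text):]  # only the newly generated tail
        current_text = data
        yield delta

assert list(iter_deltas(["He", "Hello", "Hello!"])) == ["He", "llo", "!"]
```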