
Commit bad31e1
add timeout (#48)
Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin authored Mar 6, 2024
1 parent 40c294d commit bad31e1
Showing 23 changed files with 68 additions and 48 deletions.
8 changes: 8 additions & 0 deletions generate/chat_completion/base.py
@@ -97,27 +97,35 @@ def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs

     @override
     def generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput:
+        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
         request_parameters = self._get_request_parameters(prompt, **kwargs)
+        request_parameters['timeout'] = timeout
         response = self.http_client.post(request_parameters=request_parameters)
         return self._process_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput:
+        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
         request_parameters = self._get_request_parameters(prompt, **kwargs)
+        request_parameters['timeout'] = timeout
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._process_reponse(response.json())

     @override
     def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatCompletionStreamOutput]:
+        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
         request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs)
+        request_parameters['timeout'] = timeout
         stream_manager = StreamManager(info=self.model_info)
         for line in self.http_client.stream_post(request_parameters=request_parameters):
             if output := self._process_stream_line(line, stream_manager):
                 yield output

     @override
     async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIterator[ChatCompletionStreamOutput]:
+        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
         request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs)
+        request_parameters['timeout'] = timeout
         stream_manager = StreamManager(info=self.model_info)
         async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
             if output := self._process_stream_line(line, stream_manager):
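Net effect of the four changes above: every generation entry point now accepts a `timeout` keyword, pops it out of the kwargs before they are treated as model parameters, and attaches it to the outgoing HTTP request. A minimal usage sketch follows; the concrete model class, its import path, and the stream-output attribute names are assumptions for illustration and are not part of this diff:

# Hypothetical usage of the per-request timeout added in this commit.
# OpenAIChat stands in for any ChatCompletionModel subclass; adjust the
# import to wherever the class is exported in your version of the library.
from generate.chat_completion import OpenAIChat

model = OpenAIChat()

# 'timeout' is popped from kwargs inside generate(), so it is forwarded to
# the HTTP client instead of hitting model-parameter validation.
output = model.generate('Hello!', timeout=10)
print(output.message.content)

# The same keyword works for the streaming entry points; attribute names
# here follow the Stream/ChatCompletionStreamOutput classes referenced above.
for chunk in model.stream_generate('Hello!', timeout=10):
    print(chunk.stream.delta, end='')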
4 changes: 2 additions & 2 deletions generate/chat_completion/models/anthropic.py
@@ -24,7 +24,7 @@
 from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
 from generate.chat_completion.stream_manager import StreamManager
 from generate.http import HttpClient, HttpxPostKwargs
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms import AnthropicSettings
 from generate.types import Probability, Temperature

@@ -44,7 +44,7 @@ class AnthropicChatParameters(ModelParameters):
     top_k: Optional[PositiveInt] = None


-class AnthropicParametersDict(ModelParametersDict, total=False):
+class AnthropicParametersDict(RemoteModelParametersDict, total=False):
     system: Optional[str]
     max_tokens: PositiveInt
     metadata: Optional[Dict[str, Any]]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/baichuan.py
@@ -26,7 +26,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.baichuan import BaichuanSettings
 from generate.types import Probability, Temperature

@@ -43,7 +43,7 @@ class BaichuanChatParameters(ModelParameters):
     search: Optional[bool] = Field(default=None, alias='with_search_enhance')


-class BaichuanChatParametersDict(ModelParametersDict, total=False):
+class BaichuanChatParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Temperature]
     top_k: Optional[int]
     top_p: Optional[Probability]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/bailian.py
@@ -25,7 +25,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.bailian import BailianSettings, BailianTokenManager
 from generate.types import Probability

@@ -67,7 +67,7 @@ def custom_model_dump(self) -> dict[str, Any]:
         return output


-class BailianChatParametersDict(ModelParametersDict, total=False):
+class BailianChatParametersDict(RemoteModelParametersDict, total=False):
     request_id: str
     top_p: Optional[Probability]
     top_k: Optional[int]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/dashscope.py
@@ -23,7 +23,7 @@
     HttpxPostKwargs,
     ResponseValue,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.dashscope import DashScopeSettings
 from generate.types import Probability

@@ -44,7 +44,7 @@ class DashScopeChatParameters(ModelParameters):
     search: Annotated[Optional[bool], Field(alias='enable_search')] = None


-class DashScopeChatParametersDict(ModelParametersDict, total=False):
+class DashScopeChatParametersDict(RemoteModelParametersDict, total=False):
     seed: Optional[PositiveInt]
     max_tokens: Optional[PositiveInt]
     top_p: Optional[Probability]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/dashscope_multimodal.py
@@ -30,7 +30,7 @@
     HttpxPostKwargs,
     ResponseValue,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.dashscope import DashScopeSettings
 from generate.types import Probability

@@ -41,7 +41,7 @@ class DashScopeMultiModalChatParameters(ModelParameters):
     top_k: Optional[Annotated[int, Field(ge=0, le=100)]] = None


-class DashScopeMultiModalChatParametersDict(ModelParametersDict, total=False):
+class DashScopeMultiModalChatParametersDict(RemoteModelParametersDict, total=False):
     seed: Optional[PositiveInt]
     top_p: Optional[Probability]
     top_k: Optional[int]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/deepseek.py
@@ -9,7 +9,7 @@
 from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
 from generate.chat_completion.models.openai_like import OpenAILikeChat
 from generate.http import HttpClient
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms import DeepSeekSettings
 from generate.types import Probability

@@ -23,7 +23,7 @@ class DeepSeekChatParameters(ModelParameters):
     stop: Optional[Union[str, List[str]]] = None


-class DeepSeekParametersDict(ModelParametersDict, total=False):
+class DeepSeekParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[float]
     top_p: Optional[Probability]
     max_tokens: Optional[PositiveInt]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/hunyuan.py
@@ -29,7 +29,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.hunyuan import HunyuanSettings
 from generate.types import Probability, Temperature

@@ -44,7 +44,7 @@ class HunyuanChatParameters(ModelParameters):
     top_p: Optional[Probability] = None


-class HunyuanChatParametersDict(ModelParametersDict, total=False):
+class HunyuanChatParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Temperature]
     top_p: Optional[Probability]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/minimax.py
@@ -25,7 +25,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.minimax import MinimaxSettings
 from generate.types import Probability, Temperature

@@ -62,7 +62,7 @@ def custom_model_dump(self) -> dict[str, Any]:
         return output


-class MinimaxChatParametersDict(ModelParametersDict, total=False):
+class MinimaxChatParametersDict(RemoteModelParametersDict, total=False):
     system_prompt: str
     role_meta: RoleMeta
     beam_width: Optional[int]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/minimax_pro.py
@@ -28,7 +28,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelInfo, ModelParameters, ModelParametersDict
+from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict
 from generate.platforms.minimax import MinimaxSettings
 from generate.types import OrIterable, Probability, Temperature
 from generate.utils import ensure_iterable

@@ -118,7 +118,7 @@ def custom_model_dump(self) -> dict[str, Any]:
         return output


-class MinimaxProChatParametersDict(ModelParametersDict, total=False):
+class MinimaxProChatParametersDict(RemoteModelParametersDict, total=False):
     reply_constraints: ReplyConstrainsDict
     bot_setting: List[BotSettingDict]
     temperature: Optional[Temperature]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/moonshot.py
@@ -9,7 +9,7 @@
 from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
 from generate.chat_completion.models.openai_like import OpenAILikeChat
 from generate.http import HttpClient
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms import MoonshotSettings
 from generate.types import Probability, Temperature

@@ -20,7 +20,7 @@ class MoonshotChatParameters(ModelParameters):
     max_tokens: Optional[PositiveInt] = None


-class MoonshotParametersDict(ModelParametersDict, total=False):
+class MoonshotParametersDict(RemoteModelParametersDict, total=False):
     temperature: Temperature
     top_p: Probability
     max_tokens: PositiveInt
4 changes: 2 additions & 2 deletions generate/chat_completion/models/openai.py
@@ -19,7 +19,7 @@
 from generate.http import (
     HttpClient,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.openai import OpenAISettings
 from generate.types import OrIterable, Probability, Temperature
 from generate.utils import ensure_iterable

@@ -42,7 +42,7 @@ class OpenAIChatParameters(ModelParameters):
     tool_choice: Union[Literal['auto'], OpenAIToolChoice, None] = None


-class OpenAIChatParametersDict(ModelParametersDict, total=False):
+class OpenAIChatParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Temperature]
     top_p: Optional[Probability]
     max_tokens: Optional[PositiveInt]
14 changes: 8 additions & 6 deletions generate/chat_completion/models/test.py
@@ -11,14 +11,14 @@
 )
 from generate.chat_completion.message import AssistantMessage, Prompt, ensure_messages
 from generate.chat_completion.model_output import Stream
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict


 class FakeChatParameters(ModelParameters):
     prefix: str = 'Completed:'


-class FakeChatParametersDict(ModelParametersDict, total=False):
+class FakeChatParametersDict(RemoteModelParametersDict, total=False):
     prefix: str


@@ -28,19 +28,21 @@ class FakeChat(ChatCompletionModel):
     def __init__(self, parameters: FakeChatParameters | None = None) -> None:
         self.parameters = parameters or FakeChatParameters()

-    def generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
         return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content))

-    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
         return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content))

-    def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> Iterator[ChatCompletionStreamOutput]:
+    def stream_generate(
+        self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
Expand All @@ -56,7 +58,7 @@ def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict])
)

async def async_stream_generate(
self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]
self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]
) -> AsyncIterator[ChatCompletionStreamOutput]:
messages = ensure_messages(prompt)
parameters = self.parameters.clone_with_changes(**kwargs)
Expand Down
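The FakeChat signatures above show the typing pattern behind the rename: RemoteModelParametersDict is a TypedDict, subclassing it with total=False makes every parameter key optional, and Unpack lets type checkers validate **kwargs against the declared keys. A self-contained sketch of the same pattern follows; the names are simplified stand-ins, not the library's own, and the timeout key is an assumption matching the kwarg this commit threads through:

# Standalone sketch of the TypedDict + Unpack kwargs pattern used by
# FakeChat above. All names here are illustrative stand-ins.
from typing import Optional, TypedDict

from typing_extensions import Unpack


class RemoteParamsDict(TypedDict, total=False):
    temperature: Optional[float]
    timeout: Optional[int]  # transport-level option, not a model parameter


def generate(prompt: str, **kwargs: Unpack[RemoteParamsDict]) -> str:
    # Same move as base.py above: pull the transport-level timeout out of
    # kwargs before the rest is treated as model parameters.
    timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
    return f'{prompt} (params={kwargs}, timeout={timeout})'


print(generate('hi', temperature=0.7, timeout=10))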
4 changes: 2 additions & 2 deletions generate/chat_completion/models/wenxin.py
@@ -28,7 +28,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.baidu import QianfanSettings, QianfanTokenManager
 from generate.types import JsonSchema, OrIterable, Probability, Temperature
 from generate.utils import ensure_iterable

@@ -124,7 +124,7 @@ def custom_model_dump(self) -> dict[str, Any]:
         return output


-class WenxinChatParametersDict(ModelParametersDict, total=False):
+class WenxinChatParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Temperature]
     top_p: Optional[Probability]
     functions: Optional[List[WenxinFunction]]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/yi.py
@@ -9,7 +9,7 @@
 from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
 from generate.chat_completion.models.openai_like import OpenAILikeChat
 from generate.http import HttpClient
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms import YiSettings


@@ -18,7 +18,7 @@ class YiChatParameters(ModelParameters):
     max_tokens: Optional[PositiveInt] = None


-class YiParametersDict(ModelParametersDict, total=False):
+class YiParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Annotated[float, Field(ge=0, lt=2)]]
     max_tokens: Optional[PositiveInt]
6 changes: 3 additions & 3 deletions generate/chat_completion/models/zhipu.py
@@ -33,7 +33,7 @@
     ResponseValue,
     UnexpectedResponseError,
 )
-from generate.model import ModelInfo, ModelParameters, ModelParametersDict
+from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict
 from generate.platforms.zhipu import ZhipuSettings, generate_zhipu_token
 from generate.types import JsonSchema, Probability, Temperature

@@ -92,7 +92,7 @@ def can_not_equal_zero(cls, v: Optional[Temperature]) -> Optional[Temperature]:
         return v


-class ZhipuChatParametersDict(ModelParametersDict, total=False):
+class ZhipuChatParametersDict(RemoteModelParametersDict, total=False):
     temperature: Optional[Temperature]
     top_p: Optional[Probability]
     request_id: Optional[str]

@@ -448,7 +448,7 @@ def custom_model_dump(self) -> dict[str, Any]:
         return output


-class ZhipuCharacterChatParametersDict(ModelParametersDict, total=False):
+class ZhipuCharacterChatParametersDict(RemoteModelParametersDict, total=False):
     meta: ZhipuMeta
     request_id: Optional[str]
4 changes: 2 additions & 2 deletions generate/image_generation/models/baidu.py
@@ -13,7 +13,7 @@
     ImageGenerationOutput,
     RemoteImageGenerationModel,
 )
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.baidu import BaiduCreationSettings, BaiduCreationTokenManager

 ValidSize = Literal[

@@ -52,7 +52,7 @@ def custom_model_dump(self) -> dict[str, Any]:
         return output_data


-class BaiduImageGenerationParametersDict(ModelParametersDict, total=False):
+class BaiduImageGenerationParametersDict(RemoteModelParametersDict, total=False):
     size: ValidSize
     n: Optional[int]
     reference_image: Union[HttpUrl, Base64Str, None]
4 changes: 2 additions & 2 deletions generate/image_generation/models/openai.py
@@ -9,7 +9,7 @@

 from generate.http import HttpClient, HttpxPostKwargs
 from generate.image_generation.base import GeneratedImage, ImageGenerationOutput, RemoteImageGenerationModel
-from generate.model import ModelParameters, ModelParametersDict
+from generate.model import ModelParameters, RemoteModelParametersDict
 from generate.platforms.openai import OpenAISettings

 MAX_PROMPT_LENGTH_DALLE_3 = 4000

@@ -46,7 +46,7 @@ class OpenAIImageGenerationParameters(ModelParameters):
     user: Optional[str] = None


-class OpenAIImageGenerationParametersDict(ModelParametersDict, total=False):
+class OpenAIImageGenerationParametersDict(RemoteModelParametersDict, total=False):
     quality: Optional[Literal['hd', 'standard']]
     response_format: Optional[Literal['url', 'b64_json']]
     size: Optional[Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']]
[Diff truncated: the remaining changed files are not shown.]
