Skip to content

Commit

Permalink
update type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
WHALEEYE committed Aug 31, 2024
1 parent 440f46e commit 48b6f4e
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions crab/agents/backend_models/camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import base64
import io
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any

from PIL import Image

Expand Down Expand Up @@ -44,7 +44,7 @@ def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
)


def find_model_type(model_name: str) -> Union[ModelType, str]:
def find_model_type(model_name: str) -> ModelType | str:
for model in ModelType:
if model.value.lower() == model_name.lower():
return model
Expand All @@ -61,7 +61,7 @@ def __init__(
self,
model: str,
model_platform: str,
parameters: Optional[Dict[str, Any]] = None,
parameters: dict[str, Any] | None = None,
history_messages_len: int = 0,
) -> None:
if not CAMEL_ENABLED:
Expand All @@ -70,7 +70,7 @@ def __init__(
# TODO: a better way?
self.model_type = find_model_type(model)
self.model_platform_type = find_model_platform_type(model_platform)
self.client: Optional[ExternalToolAgent] = None
self.client: ExternalToolAgent | None = None
self.token_usage = 0

super().__init__(
Expand All @@ -82,7 +82,7 @@ def __init__(
def get_token_usage(self):
return self.token_usage

def reset(self, system_message: str, action_space: Optional[List[Action]]) -> None:
def reset(self, system_message: str, action_space: list[Action] | None) -> None:
action_schema = self._convert_action_to_schema(action_space)
config = self.parameters.copy()
if action_schema is not None:
Expand Down Expand Up @@ -111,14 +111,14 @@ def reset(self, system_message: str, action_space: Optional[List[Action]]) -> No

@staticmethod
def _convert_action_to_schema(
action_space: Optional[List[Action]],
) -> Optional[List[OpenAIFunction]]:
action_space: list[Action] | None,
) -> list[OpenAIFunction] | None:
if action_space is None:
return None
return [OpenAIFunction(action.entry) for action in action_space]

@staticmethod
def _convert_tool_calls_to_action_list(tool_calls) -> List[ActionOutput]:
def _convert_tool_calls_to_action_list(tool_calls) -> list[ActionOutput]:
if tool_calls is None:
return tool_calls

Expand All @@ -130,9 +130,9 @@ def _convert_tool_calls_to_action_list(tool_calls) -> List[ActionOutput]:
for call in tool_calls
]

def chat(self, messages: List[Tuple[str, MessageType]]):
def chat(self, messages: list[tuple[str, MessageType]]):
# TODO: handle multiple text messages after message refactoring
image_list: List[Image.Image] = []
image_list: list[Image.Image] = []
content = ""
for message in messages:
if message[1] == MessageType.IMAGE_JPG_BASE64:
Expand Down

0 comments on commit 48b6f4e

Please sign in to comment.