Skip to content

Commit

Permalink
refactor camel model
Browse files Browse the repository at this point in the history
  • Loading branch information
WHALEEYE committed Aug 31, 2024
1 parent 48b6f4e commit cedaca7
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 217 deletions.
51 changes: 22 additions & 29 deletions crab/agents/backend_models/camel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import base64
import io
import json
from typing import Any

Expand All @@ -34,28 +32,6 @@
CAMEL_ENABLED = False


def find_model_platform_type(model_platform_name: str) -> ModelPlatformType:
for platform in ModelPlatformType:
if platform.value.lower() == model_platform_name.lower():
return platform
all_models = [platform.value for platform in ModelPlatformType]
raise ValueError(
f"Model {model_platform_name} not found. Supported models are {all_models}"
)


def find_model_type(model_name: str) -> ModelType | str:
for model in ModelType:
if model.value.lower() == model_name.lower():
return model
return model_name


def decode_image(encoded_image: str) -> Image:
data = base64.b64decode(encoded_image)
return Image.open(io.BytesIO(data))


class CamelModel(BackendModel):
def __init__(
self,
Expand All @@ -68,8 +44,8 @@ def __init__(
raise ImportError("Please install camel-ai to use CamelModel")
self.parameters = parameters or {}
# TODO: a better way?
self.model_type = find_model_type(model)
self.model_platform_type = find_model_platform_type(model_platform)
self.model_type = self.find_model_type(model)
self.model_platform_type = self.find_model_platform_type(model_platform)
self.client: ExternalToolAgent | None = None
self.token_usage = 0

Expand Down Expand Up @@ -109,10 +85,27 @@ def reset(self, system_message: str, action_space: list[Action] | None) -> None:
)
self.token_usage = 0

@staticmethod
def find_model_platform_type(model_platform_name: str) -> "ModelPlatformType":
for platform in ModelPlatformType:
if platform.value.lower() == model_platform_name.lower():
return platform
all_models = [platform.value for platform in ModelPlatformType]
raise ValueError(
f"Model {model_platform_name} not found. Supported models are {all_models}"
)

@staticmethod
def find_model_type(model_name: str) -> "str | ModelType":
for model in ModelType:
if model.value.lower() == model_name.lower():
return model
return model_name

@staticmethod
def _convert_action_to_schema(
action_space: list[Action] | None,
) -> list[OpenAIFunction] | None:
) -> "list[OpenAIFunction] | None":
if action_space is None:
return None
return [OpenAIFunction(action.entry) for action in action_space]
Expand Down Expand Up @@ -147,12 +140,12 @@ def chat(self, messages: list[tuple[str, MessageType]]):
)
response = self.client.step(usermsg)
self.token_usage += response.info["usage"]["total_tokens"]
tool_calls = response.info.get("tool_call_requests")
tool_call_request = response.info.get("tool_call_request")

# TODO: delete this after record_message is refactored
self.client.record_message(response.msg)

return BackendOutput(
message=response.msg.content,
action_list=self._convert_tool_calls_to_action_list(tool_calls),
action_list=self._convert_tool_calls_to_action_list([tool_call_request]),
)
Loading

0 comments on commit cedaca7

Please sign in to comment.