Skip to content

Commit

Permalink
feat: add gemini model families, enhance group chat selection for Gem…
Browse files Browse the repository at this point in the history
…ini model and add tests (#5334)

Resolves #5322
  • Loading branch information
ekzhu authored Feb 3, 2025
1 parent 9af6883 commit 569bc19
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core import Component, ComponentModel
from autogen_core.models import ChatCompletionClient, SystemMessage
from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -135,7 +135,11 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
select_speaker_prompt = self._selector_prompt.format(
roles=roles, participants=str(participants), history=history
)
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
select_speaker_messages: List[SystemMessage | UserMessage]
if self._model_client.model_info["family"].startswith("gemini"):
select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")]
else:
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os

import pytest
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core.models import ModelFamily
from autogen_ext.models.openai import OpenAIChatCompletionClient


@pytest.mark.asyncio
async def test_selector_group_chat_gemini() -> None:
try:
api_key = os.environ["GEMINI_API_KEY"]
except KeyError:
pytest.skip("GEMINI_API_KEY not set in environment variables.")

model_client = OpenAIChatCompletionClient(
model="gemini-1.5-flash",
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
api_key=api_key,
model_info={
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
},
)

assistant = AssistantAgent(
"assistant",
description="A helpful assistant agent.",
model_client=model_client,
system_message="You are a helpful assistant.",
)

critic = AssistantAgent(
"critic",
description="A critic agent to provide feedback.",
model_client=model_client,
system_message="Provide feedback.",
)

team = SelectorGroupChat([assistant, critic], model_client=model_client, max_turns=2)
await Console(team.run_stream(task="Draft a short email about organizing a holiday party for new year."))
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@
"source": [
"import os\n",
"\n",
"from autogen_core.models import UserMessage\n",
"from autogen_core.models import ModelFamily, UserMessage\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"model_client = OpenAIChatCompletionClient(\n",
Expand All @@ -320,7 +320,7 @@
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" \"family\": \"unknown\",\n",
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
" },\n",
")\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@
"source": [
"import os\n",
"\n",
"from autogen_core.models import UserMessage\n",
"from autogen_core.models import ModelFamily, UserMessage\n",
"from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
"\n",
"model_client = OpenAIChatCompletionClient(\n",
Expand All @@ -328,7 +328,7 @@
" \"vision\": True,\n",
" \"function_calling\": True,\n",
" \"json_output\": True,\n",
" \"family\": \"unknown\",\n",
" \"family\": ModelFamily.GEMINI_1_5_FLASH,\n",
" },\n",
")\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,23 @@ class ModelFamily:
GPT_4 = "gpt-4"
GPT_35 = "gpt-35"
R1 = "r1"
GEMINI_1_5_FLASH = "gemini-1.5-flash"
GEMINI_1_5_PRO = "gemini-1.5-pro"
GEMINI_2_0_FLASH = "gemini-2.0-flash"
UNKNOWN = "unknown"

ANY: TypeAlias = Literal["gpt-4o", "o1", "o3", "gpt-4", "gpt-35", "r1", "unknown"]
ANY: TypeAlias = Literal[
"gpt-4o",
"o1",
"o3",
"gpt-4",
"gpt-35",
"r1",
"gemini-1.5-flash",
"gemini-1.5-pro",
"gemini-2.0-flash",
"unknown",
]

def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,17 @@ async def main() -> None:
temperature=0.2,
)
model_client = SKChatCompletionAdapter(sk_client, kernel=Kernel(memory=NullMemory()), prompt_settings=settings)
model_client = SKChatCompletionAdapter(
sk_client,
kernel=Kernel(memory=NullMemory()),
prompt_settings=settings,
model_info={
"family": "gemini-1.5-flash",
"function_calling": True,
"json_output": True,
"vision": False,
},
)
# Call the model directly.
model_result = await model_client.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ async def test_gemini() -> None:
"function_calling": True,
"json_output": True,
"vision": True,
"family": ModelFamily.UNKNOWN,
"family": ModelFamily.GEMINI_1_5_FLASH,
},
)
await _test_model_client_basic_completion(model_client)
Expand Down

0 comments on commit 569bc19

Please sign in to comment.