diff --git a/autogen/oai/github.py b/autogen/oai/github.py index 31c0d07b1e74..355f48af554e 100644 --- a/autogen/oai/github.py +++ b/autogen/oai/github.py @@ -1,4 +1,4 @@ -'''Create a Github LLM Client with Azure Fallback. +"""Create a Github LLM Client with Azure Fallback. # Usage example: if __name__ == "__main__": @@ -7,9 +7,9 @@ "system_prompt": "You are a knowledgeable history teacher.", "use_azure_fallback": True } - + wrapper = GithubWrapper(config_list=[config]) - + response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}]) print(wrapper.message_retrieval(response)[0]) @@ -18,17 +18,16 @@ {"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."}, {"role": "user", "content": "What were the main causes?"} ] - + response = wrapper.create(messages=conversation) print(wrapper.message_retrieval(response)[0]) -''' +""" from __future__ import annotations - -import os +import json import logging +import os import time -import json from typing import Any, Dict, List, Optional, Union, Tuple import requests @@ -36,11 +35,12 @@ from openai.types.chat.chat_completion import Choice from openai.types.completion_usage import CompletionUsage -from autogen.oai.client_utils import should_hide_tools, validate_parameter from autogen.cache import Cache +from autogen.oai.client_utils import should_hide_tools, validate_parameter logger = logging.getLogger(__name__) + class GithubClient: """GitHub LLM Client with Azure Fallback""" @@ -66,7 +66,7 @@ class GithubClient: "phi-3-mini-instruct-128k", "phi-3-mini-instruct-4k", "phi-3-small-instruct-128k", - "phi-3-small-instruct-8k" + "phi-3-small-instruct-8k", ] def __init__(self, **kwargs): @@ -97,22 +97,15 @@ def message_retrieval(self, response: ChatCompletion) -> List[str]: def create(self, params: Dict[str, Any]) -> ChatCompletion: """Create a completion for a given config.""" messages = params.get("messages", []) - + if "system" not in [m["role"] for m in messages]: messages.insert(0, {"role": "system", "content": self.system_prompt}) - data = { - "messages": messages, - "model": self.model, - **params - } + data = {"messages": messages, "model": self.model, **params} if self._check_rate_limit(): try: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.github_token}" - } + headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.github_token}"} response = self._call_api(self.github_endpoint_url, headers, data) self._increment_request_count() @@ -121,10 +114,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.") if self.use_azure_fallback: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.azure_api_key}" - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.azure_api_key}"} response = self._call_api(self.github_endpoint_url, headers, data) return self._process_response(response) @@ -157,11 +147,8 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: choices = [ Choice( index=i, - message=ChatCompletionMessage( - role="assistant", - content=choice["message"]["content"] - ), - finish_reason=choice.get("finish_reason") + message=ChatCompletionMessage(role="assistant", content=choice["message"]["content"]), + finish_reason=choice.get("finish_reason"), ) for i, choice in enumerate(response_data["choices"]) ] @@ -169,7 +156,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: usage = CompletionUsage( prompt_tokens=response_data["usage"]["prompt_tokens"], completion_tokens=response_data["usage"]["completion_tokens"], - total_tokens=response_data["usage"]["total_tokens"] + total_tokens=response_data["usage"]["total_tokens"], ) return ChatCompletion( @@ -178,7 +165,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion: created=response_data["created"], object="chat.completion", choices=choices, - usage=usage + usage=usage, ) def cost(self, response: ChatCompletion) -> float: diff --git a/test/oai/test_githubllm.py b/test/oai/test_githubllm.py index 0a3b879f1dcd..c182fc6d4bb8 100644 --- a/test/oai/test_githubllm.py +++ b/test/oai/test_githubllm.py @@ -1,33 +1,34 @@ -import pytest from unittest.mock import patch, MagicMock +import pytest + from autogen.oai.github import GithubClient, GithubWrapper @pytest.fixture def github_client(): - with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): return GithubClient(model="gpt-4o", system_prompt="Test prompt") + @pytest.fixture def github_wrapper(): - with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): - config = { - "model": "gpt-4o", - "system_prompt": "Test prompt", - "use_azure_fallback": True - } + with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): + config = { "model": "gpt-4o", "system_prompt": "Test prompt", "use_azure_fallback": True} return GithubWrapper(config_list=[config]) + def test_github_client_initialization(github_client): assert github_client.model == "gpt-4o" assert github_client.system_prompt == "Test prompt" assert github_client.use_azure_fallback == True + def test_github_client_unsupported_model(): with pytest.raises(ValueError): - with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}): + with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}): GithubClient(model="unsupported-model") -@patch('requests.post') + +@patch("requests.post") def test_github_client_create(mock_post, github_client): mock_response = MagicMock() mock_response.json.return_value = { @@ -35,7 +36,7 @@ def test_github_client_create(mock_post, github_client): "model": "gpt-4o", "created": 1234567890, "choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}], - "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, } mock_post.return_value = mock_response @@ -47,21 +48,24 @@ def test_github_client_create(mock_post, github_client): assert len(response.choices) == 1 assert response.choices[0].message.content == "Test response" + def test_github_client_message_retrieval(github_client): mock_response = MagicMock() mock_response.choices = [ MagicMock(message=MagicMock(content="Response 1")), - MagicMock(message=MagicMock(content="Response 2")) + MagicMock(message=MagicMock(content="Response 2")), ] - + messages = github_client.message_retrieval(mock_response) assert messages == ["Response 1", "Response 2"] + def test_github_client_cost(github_client): mock_response = MagicMock() cost = github_client.cost(mock_response) assert cost == 0.0 # Assuming the placeholder implementation + def test_github_client_get_usage(github_client): mock_response = MagicMock() mock_response.usage.prompt_tokens = 10 @@ -75,7 +79,8 @@ def test_github_client_get_usage(github_client): assert usage["total_tokens"] == 30 assert usage["model"] == "gpt-4o" -@patch('autogen.oai.github.GithubClient.create') + +@patch("autogen.oai.github.GithubClient.create") def test_github_wrapper_create(mock_create, github_wrapper): mock_response = MagicMock() mock_create.return_value = mock_response @@ -84,29 +89,31 @@ def test_github_wrapper_create(mock_create, github_wrapper): response = github_wrapper.create(**params) assert response == mock_response - assert hasattr(response, 'config_id') + assert hasattr(response, "config_id") mock_create.assert_called_once_with(params) def test_github_wrapper_message_retrieval(github_wrapper): mock_response = MagicMock() mock_response.config_id = 0 - - with patch.object(github_wrapper._clients[0], 'message_retrieval') as mock_retrieval: + + + with patch.object(github_wrapper._clients[0], "message_retrieval") as mock_retrieval: mock_retrieval.return_value = ["Test message"] messages = github_wrapper.message_retrieval(mock_response) - + assert messages == ["Test message"] def test_github_wrapper_cost(github_wrapper): mock_response = MagicMock() mock_response.config_id = 0 - - with patch.object(github_wrapper._clients[0], 'cost') as mock_cost: + + with patch.object(github_wrapper._clients[0], "cost") as mock_cost: mock_cost.return_value = 0.05 cost = github_wrapper.cost(mock_response) - + assert cost == 0.05 + def test_github_wrapper_get_usage(github_wrapper): mock_response = MagicMock() mock_response.usage.prompt_tokens = 10