Skip to content

Commit

Permalink
Fix linting issues (string quote style, trailing commas, import ordering)
Browse files Browse the repository at this point in the history
  • Loading branch information
Josephrp committed Aug 6, 2024
1 parent 1448b9c commit 1e0971c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 52 deletions.
49 changes: 18 additions & 31 deletions autogen/oai/github.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'''Create a Github LLM Client with Azure Fallback.
"""Create a Github LLM Client with Azure Fallback.
# Usage example:
if __name__ == "__main__":
Expand All @@ -7,9 +7,9 @@
"system_prompt": "You are a knowledgeable history teacher.",
"use_azure_fallback": True
}
wrapper = GithubWrapper(config_list=[config])
response = wrapper.create(messages=[{"role": "user", "content": "What is the capital of France?"}])
print(wrapper.message_retrieval(response)[0])
Expand All @@ -18,29 +18,29 @@
{"role": "assistant", "content": "The French Revolution was a period of major social and political upheaval in France that began in 1789 with the Storming of the Bastille and ended in the late 1790s with the ascent of Napoleon Bonaparte."},
{"role": "user", "content": "What were the main causes?"}
]
response = wrapper.create(messages=conversation)
print(wrapper.message_retrieval(response)[0])
'''
"""

from __future__ import annotations

import os
import json
import logging
import os
import time
import json
from typing import Any, Dict, List, Optional, Union, Tuple

import requests
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.completion_usage import CompletionUsage

from autogen.oai.client_utils import should_hide_tools, validate_parameter
from autogen.cache import Cache
from autogen.oai.client_utils import should_hide_tools, validate_parameter

logger = logging.getLogger(__name__)


class GithubClient:
"""GitHub LLM Client with Azure Fallback"""

Expand All @@ -66,7 +66,7 @@ class GithubClient:
"phi-3-mini-instruct-128k",
"phi-3-mini-instruct-4k",
"phi-3-small-instruct-128k",
"phi-3-small-instruct-8k"
"phi-3-small-instruct-8k",
]

def __init__(self, **kwargs):
Expand Down Expand Up @@ -97,22 +97,15 @@ def message_retrieval(self, response: ChatCompletion) -> List[str]:
def create(self, params: Dict[str, Any]) -> ChatCompletion:
"""Create a completion for a given config."""
messages = params.get("messages", [])

if "system" not in [m["role"] for m in messages]:
messages.insert(0, {"role": "system", "content": self.system_prompt})

data = {
"messages": messages,
"model": self.model,
**params
}
data = {"messages": messages, "model": self.model, **params}

if self._check_rate_limit():
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.github_token}"
}
headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.github_token}"}

response = self._call_api(self.github_endpoint_url, headers, data)
self._increment_request_count()
Expand All @@ -121,10 +114,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
logger.warning(f"GitHub API call failed: {str(e)}. Falling back to Azure.")

if self.use_azure_fallback:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.azure_api_key}"
}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.azure_api_key}"}

response = self._call_api(self.github_endpoint_url, headers, data)
return self._process_response(response)
Expand Down Expand Up @@ -157,19 +147,16 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
choices = [
Choice(
index=i,
message=ChatCompletionMessage(
role="assistant",
content=choice["message"]["content"]
),
finish_reason=choice.get("finish_reason")
message=ChatCompletionMessage(role="assistant", content=choice["message"]["content"]),
finish_reason=choice.get("finish_reason"),
)
for i, choice in enumerate(response_data["choices"])
]

usage = CompletionUsage(
prompt_tokens=response_data["usage"]["prompt_tokens"],
completion_tokens=response_data["usage"]["completion_tokens"],
total_tokens=response_data["usage"]["total_tokens"]
total_tokens=response_data["usage"]["total_tokens"],
)

return ChatCompletion(
Expand All @@ -178,7 +165,7 @@ def _process_response(self, response_data: Dict[str, Any]) -> ChatCompletion:
created=response_data["created"],
object="chat.completion",
choices=choices,
usage=usage
usage=usage,
)

def cost(self, response: ChatCompletion) -> float:
Expand Down
49 changes: 28 additions & 21 deletions test/oai/test_githubllm.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
import pytest
from unittest.mock import patch, MagicMock
import pytest

from autogen.oai.github import GithubClient, GithubWrapper

@pytest.fixture
def github_client():
with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
return GithubClient(model="gpt-4o", system_prompt="Test prompt")


@pytest.fixture
def github_wrapper():
with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
config = {
"model": "gpt-4o",
"system_prompt": "Test prompt",
"use_azure_fallback": True
}
with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
config = { "model": "gpt-4o", "system_prompt": "Test prompt", "use_azure_fallback": True}
return GithubWrapper(config_list=[config])


def test_github_client_initialization(github_client):
assert github_client.model == "gpt-4o"
assert github_client.system_prompt == "Test prompt"
assert github_client.use_azure_fallback == True


def test_github_client_unsupported_model():
with pytest.raises(ValueError):
with patch.dict('os.environ', {'GITHUB_TOKEN': 'fake_token', 'AZURE_API_KEY': 'fake_azure_key'}):
with patch.dict("os.environ", {"GITHUB_TOKEN": "fake_token", "AZURE_API_KEY": "fake_azure_key"}):
GithubClient(model="unsupported-model")

@patch('requests.post')

@patch("requests.post")
def test_github_client_create(mock_post, github_client):
mock_response = MagicMock()
mock_response.json.return_value = {
"id": "test_id",
"model": "gpt-4o",
"created": 1234567890,
"choices": [{"message": {"content": "Test response"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
mock_post.return_value = mock_response

Expand All @@ -47,21 +48,24 @@ def test_github_client_create(mock_post, github_client):
assert len(response.choices) == 1
assert response.choices[0].message.content == "Test response"


def test_github_client_message_retrieval(github_client):
mock_response = MagicMock()
mock_response.choices = [
MagicMock(message=MagicMock(content="Response 1")),
MagicMock(message=MagicMock(content="Response 2"))
MagicMock(message=MagicMock(content="Response 2")),
]

messages = github_client.message_retrieval(mock_response)
assert messages == ["Response 1", "Response 2"]


def test_github_client_cost(github_client):
mock_response = MagicMock()
cost = github_client.cost(mock_response)
assert cost == 0.0 # Assuming the placeholder implementation


def test_github_client_get_usage(github_client):
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 10
Expand All @@ -75,7 +79,8 @@ def test_github_client_get_usage(github_client):
assert usage["total_tokens"] == 30
assert usage["model"] == "gpt-4o"

@patch('autogen.oai.github.GithubClient.create')

@patch("autogen.oai.github.GithubClient.create")
def test_github_wrapper_create(mock_create, github_wrapper):
mock_response = MagicMock()
mock_create.return_value = mock_response
Expand All @@ -84,29 +89,31 @@ def test_github_wrapper_create(mock_create, github_wrapper):
response = github_wrapper.create(**params)

assert response == mock_response
assert hasattr(response, 'config_id')
assert hasattr(response, "config_id")
mock_create.assert_called_once_with(params)

def test_github_wrapper_message_retrieval(github_wrapper):
mock_response = MagicMock()
mock_response.config_id = 0

with patch.object(github_wrapper._clients[0], 'message_retrieval') as mock_retrieval:


with patch.object(github_wrapper._clients[0], "message_retrieval") as mock_retrieval:
mock_retrieval.return_value = ["Test message"]
messages = github_wrapper.message_retrieval(mock_response)

assert messages == ["Test message"]

def test_github_wrapper_cost(github_wrapper):
mock_response = MagicMock()
mock_response.config_id = 0
with patch.object(github_wrapper._clients[0], 'cost') as mock_cost:

with patch.object(github_wrapper._clients[0], "cost") as mock_cost:
mock_cost.return_value = 0.05
cost = github_wrapper.cost(mock_response)

assert cost == 0.05


def test_github_wrapper_get_usage(github_wrapper):
mock_response = MagicMock()
mock_response.usage.prompt_tokens = 10
Expand Down

0 comments on commit 1e0971c

Please sign in to comment.