From d19bf74cd6bdbd7339344d37593d61141cd64135 Mon Sep 17 00:00:00 2001
From: Reuben Thomas-Davis
Date: Tue, 5 Sep 2023 04:03:25 +0100
Subject: [PATCH] Support caching of async completion and cache completion
 (#513)

* Use the old version for the chromadb (#492)

Signed-off-by: SimFG
Signed-off-by: Reuben Thomas-Davis

* added support for weaviate vector database (#493)

* added support for weaviate vector database

Signed-off-by: pranaychandekar

* added support for local db for weaviate vector store

Signed-off-by: pranaychandekar

* added unit test case for weaviate vector store

Signed-off-by: pranaychandekar

* resolved unit test case error for weaviate vector store

Signed-off-by: pranaychandekar

* increased code coverage

resolved pylint issues

pylint: disabled C0413

Signed-off-by: pranaychandekar

---------

Signed-off-by: pranaychandekar
Signed-off-by: Reuben Thomas-Davis

* Update the version to `0.1.37` (#494)

Signed-off-by: SimFG
Signed-off-by: Reuben Thomas-Davis

* :sparkles: support caching of async completion and cache completion

Signed-off-by: Reuben Thomas-Davis

* :sparkles: add streaming support for chatcompletion

Signed-off-by: Reuben Thomas-Davis

* :white_check_mark: improve test coverage and formatting

Signed-off-by: Reuben Thomas-Davis

* :sparkles: support caching of async completion and cache completion

Signed-off-by: Reuben Thomas-Davis

* :sparkles: add streaming support for chatcompletion

Signed-off-by: Reuben Thomas-Davis

* :white_check_mark: improve test coverage and formatting

Signed-off-by: Reuben Thomas-Davis

* correct merge duplication

Signed-off-by: Reuben Thomas-Davis

* correct update cache callback

Signed-off-by: Reuben Thomas-Davis

* add additional tests for improved coverage

Signed-off-by: Reuben Thomas-Davis

* remove redundant param in docstring

Signed-off-by: Reuben Thomas-Davis

---------

Signed-off-by: SimFG
Signed-off-by: Reuben Thomas-Davis
Signed-off-by: pranaychandekar
Co-authored-by: SimFG
Co-authored-by: Pranay Chandekar
---
 gptcache/adapter/adapter.py             |   1 -
 gptcache/adapter/openai.py              | 109 +++++++-
 tests/unit_tests/adapter/test_openai.py | 327 ++++++++++++++++++++++--
 3 files changed, 403 insertions(+), 34 deletions(-)

diff --git a/gptcache/adapter/adapter.py b/gptcache/adapter/adapter.py
index e167ec24..0e4fe490 100644
--- a/gptcache/adapter/adapter.py
+++ b/gptcache/adapter/adapter.py
@@ -513,7 +513,6 @@ def update_cache_func(handled_llm_data, question=None):
                     == 0
                 ):
                     chat_cache.flush()
-
         llm_data = update_cache_callback(
             llm_data, update_cache_func, *args, **kwargs
         )
diff --git a/gptcache/adapter/openai.py b/gptcache/adapter/openai.py
index 6f3d50aa..ddc3ca89 100644
--- a/gptcache/adapter/openai.py
+++ b/gptcache/adapter/openai.py
@@ -3,21 +3,21 @@
 import os
 import time
 from io import BytesIO
-from typing import Iterator, Any, List
+from typing import Any, AsyncGenerator, Iterator, List
 
 from gptcache import cache
-from gptcache.adapter.adapter import adapt
+from gptcache.adapter.adapter import aadapt, adapt
 from gptcache.adapter.base import BaseCacheLLM
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.utils import import_openai, import_pillow
 from gptcache.utils.error import wrap_error
 from gptcache.utils.response import (
-    get_stream_message_from_openai_answer,
-    get_message_from_openai_answer,
-    get_text_from_openai_answer,
+    get_audio_text_from_openai_answer,
     get_image_from_openai_b64,
     get_image_from_openai_url,
-    get_audio_text_from_openai_answer,
+    get_message_from_openai_answer,
+    get_stream_message_from_openai_answer,
+
get_text_from_openai_answer, ) from gptcache.utils.token import token_counter @@ -56,7 +56,22 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM): @classmethod def _llm_handler(cls, *llm_args, **llm_kwargs): try: - return super().create(*llm_args, **llm_kwargs) if cls.llm is None else cls.llm(*llm_args, **llm_kwargs) + return ( + super().create(*llm_args, **llm_kwargs) + if cls.llm is None + else cls.llm(*llm_args, **llm_kwargs) + ) + except openai.OpenAIError as e: + raise wrap_error(e) from e + + @classmethod + async def _allm_handler(cls, *llm_args, **llm_kwargs): + try: + return ( + (await super().acreate(*llm_args, **llm_kwargs)) + if cls.llm is None + else await cls.llm(*llm_args, **llm_kwargs) + ) except openai.OpenAIError as e: raise wrap_error(e) from e @@ -64,7 +79,17 @@ def _llm_handler(cls, *llm_args, **llm_kwargs): def _update_cache_callback( llm_data, update_cache_func, *args, **kwargs ): # pylint: disable=unused-argument - if not isinstance(llm_data, Iterator): + if isinstance(llm_data, AsyncGenerator): + + async def hook_openai_data(it): + total_answer = "" + async for item in it: + total_answer += get_stream_message_from_openai_answer(item) + yield item + update_cache_func(Answer(total_answer, DataType.STR)) + + return hook_openai_data(llm_data) + elif not isinstance(llm_data, Iterator): update_cache_func( Answer(get_message_from_openai_answer(llm_data), DataType.STR) ) @@ -92,8 +117,6 @@ def cache_data_convert(cache_data): saved_token = [input_token, output_token] else: saved_token = [0, 0] - if kwargs.get("stream", False): - return _construct_stream_resp_from_cache(cache_data, saved_token) return _construct_resp_from_cache(cache_data, saved_token) kwargs = cls.fill_base_args(**kwargs) @@ -105,6 +128,38 @@ def cache_data_convert(cache_data): **kwargs, ) + @classmethod + async def acreate(cls, *args, **kwargs): + chat_cache = kwargs.get("cache_obj", cache) + enable_token_counter = chat_cache.config.enable_token_counter + + def cache_data_convert(cache_data): + if enable_token_counter: + input_token = _num_tokens_from_messages(kwargs.get("messages")) + output_token = token_counter(cache_data) + saved_token = [input_token, output_token] + else: + saved_token = [0, 0] + if kwargs.get("stream", False): + return async_iter( + _construct_stream_resp_from_cache(cache_data, saved_token) + ) + return _construct_resp_from_cache(cache_data, saved_token) + + kwargs = cls.fill_base_args(**kwargs) + return await aadapt( + cls._allm_handler, + cache_data_convert, + cls._update_cache_callback, + *args, + **kwargs, + ) + + +async def async_iter(input_list): + for item in input_list: + yield item + class Completion(openai.Completion, BaseCacheLLM): """Openai Completion Wrapper @@ -128,7 +183,22 @@ class Completion(openai.Completion, BaseCacheLLM): @classmethod def _llm_handler(cls, *llm_args, **llm_kwargs): try: - return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs) + return ( + super().create(*llm_args, **llm_kwargs) + if not cls.llm + else cls.llm(*llm_args, **llm_kwargs) + ) + except openai.OpenAIError as e: + raise wrap_error(e) from e + + @classmethod + async def _allm_handler(cls, *llm_args, **llm_kwargs): + try: + return ( + (await super().acreate(*llm_args, **llm_kwargs)) + if cls.llm is None + else await cls.llm(*llm_args, **llm_kwargs) + ) except openai.OpenAIError as e: raise wrap_error(e) from e @@ -154,6 +224,17 @@ def create(cls, *args, **kwargs): **kwargs, ) + @classmethod + async def acreate(cls, *args, **kwargs): + 
kwargs = cls.fill_base_args(**kwargs) + return await aadapt( + cls._allm_handler, + cls._cache_data_convert, + cls._update_cache_callback, + *args, + **kwargs, + ) + class Audio(openai.Audio): """Openai Audio Wrapper @@ -319,7 +400,11 @@ class Moderation(openai.Moderation, BaseCacheLLM): @classmethod def _llm_handler(cls, *llm_args, **llm_kwargs): try: - return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs) + return ( + super().create(*llm_args, **llm_kwargs) + if not cls.llm + else cls.llm(*llm_args, **llm_kwargs) + ) except openai.OpenAIError as e: raise wrap_error(e) from e diff --git a/tests/unit_tests/adapter/test_openai.py b/tests/unit_tests/adapter/test_openai.py index 9d266687..d6762899 100644 --- a/tests/unit_tests/adapter/test_openai.py +++ b/tests/unit_tests/adapter/test_openai.py @@ -1,28 +1,34 @@ +import asyncio import base64 import os import random from io import BytesIO -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from urllib.request import urlopen -from gptcache import cache, Cache +import pytest + +from gptcache import Cache, cache from gptcache.adapter import openai from gptcache.adapter.api import init_similar_cache +from gptcache.config import Config from gptcache.manager import get_data_manager from gptcache.processor.pre import ( - get_prompt, - get_file_name, get_file_bytes, - get_openai_moderation_input, last_content, + get_file_name, + get_openai_moderation_input, + get_prompt, + last_content, ) +from gptcache.utils.error import CacheError from gptcache.utils.response import ( - get_stream_message_from_openai_answer, - get_message_from_openai_answer, - get_text_from_openai_answer, + get_audio_text_from_openai_answer, get_image_from_openai_b64, - get_image_from_path, get_image_from_openai_url, - get_audio_text_from_openai_answer, + get_image_from_path, + get_message_from_openai_answer, + get_stream_message_from_openai_answer, + get_text_from_openai_answer, ) try: @@ -34,8 +40,9 @@ from PIL import Image -def test_normal_openai(): - cache.init() +@pytest.mark.parametrize("enable_token_counter", (True, False)) +def test_normal_openai(enable_token_counter): + cache.init(config=Config(enable_token_counter=enable_token_counter)) question = "calculate 1+3" expect_answer = "the result is 4" with patch("openai.ChatCompletion.create") as mock_create: @@ -75,6 +82,53 @@ def test_normal_openai(): assert answer_text == expect_answer, answer_text +@pytest.mark.asyncio +@pytest.mark.parametrize("enable_token_counter", (True, False)) +async def test_normal_openai_async(enable_token_counter): + cache.init(config=Config(enable_token_counter=enable_token_counter)) + question = "calculate 1+3" + expect_answer = "the result is 4" + import openai as real_openai + + with patch.object( + real_openai.ChatCompletion, "acreate", new_callable=AsyncMock + ) as mock_acreate: + datas = { + "choices": [ + { + "message": {"content": expect_answer, "role": "assistant"}, + "finish_reason": "stop", + "index": 0, + } + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + } + mock_acreate.return_value = datas + + response = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + ) + + assert get_message_from_openai_answer(response) == expect_answer, response + + response = await 
openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + ) + answer_text = get_message_from_openai_answer(response) + assert answer_text == expect_answer, answer_text + + def test_stream_openai(): cache.init() question = "calculate 1+1" @@ -148,6 +202,91 @@ def test_stream_openai(): assert answer_text == expect_answer, answer_text +@pytest.mark.asyncio +async def test_stream_openai_async(): + cache.init() + question = "calculate 1+4" + expect_answer = "the result is 5" + import openai as real_openai + + with patch.object( + real_openai.ChatCompletion, "acreate", new_callable=AsyncMock + ) as mock_acreate: + datas = [ + { + "choices": [ + {"delta": {"role": "assistant"}, "finish_reason": None, "index": 0} + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + }, + { + "choices": [ + { + "delta": {"content": "the result"}, + "finish_reason": None, + "index": 0, + } + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + }, + { + "choices": [ + {"delta": {"content": " is 5"}, "finish_reason": None, "index": 0} + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + }, + { + "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + }, + ] + + async def acreate(*args, **kwargs): + for item in datas: + yield item + await asyncio.sleep(0) + + mock_acreate.return_value = acreate() + + response = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + stream=True, + ) + all_text = "" + async for res in response: + all_text += get_stream_message_from_openai_answer(res) + assert all_text == expect_answer, all_text + + response = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + stream=True, + ) + answer_text = "" + async for res in response: + answer_text += get_stream_message_from_openai_answer(res) + assert answer_text == expect_answer, answer_text + + def test_completion(): cache.init(pre_embedding_func=get_prompt) question = "what is your name?" @@ -171,6 +310,52 @@ def test_completion(): assert answer_text == expect_answer +@pytest.mark.asyncio +async def test_completion_async(): + cache.init(pre_embedding_func=get_prompt) + question = "what is your name?" 
+ expect_answer = "gptcache" + + with patch("openai.Completion.acreate", new_callable=AsyncMock) as mock_acreate: + mock_acreate.return_value = { + "choices": [{"text": expect_answer, "finish_reason": None, "index": 0}], + "created": 1677825464, + "id": "cmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "text-davinci-003", + "object": "text_completion", + } + + response = await openai.Completion.acreate( + model="text-davinci-003", prompt=question + ) + answer_text = get_text_from_openai_answer(response) + assert answer_text == expect_answer + + response = await openai.Completion.acreate( + model="text-davinci-003", prompt=question + ) + answer_text = get_text_from_openai_answer(response) + assert answer_text == expect_answer + + +@pytest.mark.asyncio +async def test_completion_error_wrapping(): + cache.init(pre_embedding_func=get_prompt) + import openai as real_openai + + with patch("openai.Completion.acreate", new_callable=AsyncMock) as mock_acreate: + mock_acreate.side_effect = real_openai.OpenAIError + with pytest.raises(real_openai.OpenAIError) as e: + await openai.Completion.acreate(model="text-davinci-003", prompt="boom") + assert isinstance(e.value, CacheError) + + with patch("openai.Completion.create") as mock_create: + mock_create.side_effect = real_openai.OpenAIError + with pytest.raises(real_openai.OpenAIError) as e: + openai.Completion.create(model="text-davinci-003", prompt="boom") + assert isinstance(e.value, CacheError) + + def test_image_create(): cache.init(pre_embedding_func=get_prompt) prompt1 = "test url" # bytes @@ -312,16 +497,16 @@ def test_moderation(): input=["I want to kill them."], ) assert ( - response.get("results")[0].get("category_scores").get("violence") - == expect_violence + response.get("results")[0].get("category_scores").get("violence") + == expect_violence ) response = openai.Moderation.create( input="I want to kill them.", ) assert ( - response.get("results")[0].get("category_scores").get("violence") - == expect_violence + response.get("results")[0].get("category_scores").get("violence") + == expect_violence ) expect_violence = 0.88708615 @@ -379,8 +564,8 @@ def test_moderation(): ) assert not response.get("results")[0].get("flagged") assert ( - response.get("results")[1].get("category_scores").get("violence") - == expect_violence + response.get("results")[1].get("category_scores").get("violence") + == expect_violence ) response = openai.Moderation.create( @@ -388,8 +573,8 @@ def test_moderation(): ) assert not response.get("results")[0].get("flagged") assert ( - response.get("results")[1].get("category_scores").get("violence") - == expect_violence + response.get("results")[1].get("category_scores").get("violence") + == expect_violence ) @@ -462,7 +647,7 @@ def proxy_openai_chat_complete(*args, **kwargs): is_exception = False try: - openai.ChatCompletion.create( + resp = openai.ChatCompletion.create( model="gpt-3.5-turbo", messages=[ {"role": "system", "content": "You are a helpful assistant."}, @@ -489,6 +674,106 @@ def proxy_openai_chat_complete(*args, **kwargs): openai.ChatCompletion.cache_args = {} assert get_message_from_openai_answer(response) == expect_answer, response + +@pytest.mark.asyncio +async def test_base_llm_cache_async(): + cache_obj = Cache() + init_similar_cache( + data_dir=str(random.random()), pre_func=last_content, cache_obj=cache_obj + ) + question = "What's Github" + expect_answer = "Github is a great place to start" + import openai as real_openai + + with patch.object( + real_openai.ChatCompletion, "acreate", 
new_callable=AsyncMock + ) as mock_acreate: + datas = { + "choices": [ + { + "message": {"content": expect_answer, "role": "assistant"}, + "finish_reason": "stop", + "index": 0, + } + ], + "created": 1677825464, + "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD", + "model": "gpt-3.5-turbo-0301", + "object": "chat.completion.chunk", + } + mock_acreate.return_value = datas + + async def proxy_openai_chat_complete_exception(*args, **kwargs): + raise real_openai.error.APIConnectionError("connect fail") + + openai.ChatCompletion.llm = proxy_openai_chat_complete_exception + + is_openai_exception = False + try: + await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + cache_obj=cache_obj, + ) + except real_openai.error.APIConnectionError: + is_openai_exception = True + + assert is_openai_exception + + is_proxy = False + + def proxy_openai_chat_complete(*args, **kwargs): + nonlocal is_proxy + is_proxy = True + return real_openai.ChatCompletion.acreate(*args, **kwargs) + + openai.ChatCompletion.llm = proxy_openai_chat_complete + + response = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + cache_obj=cache_obj, + ) + assert is_proxy + + assert get_message_from_openai_answer(response) == expect_answer, response + + is_exception = False + try: + resp = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + ) + except Exception: + is_exception = True + assert is_exception + + openai.ChatCompletion.cache_args = {"cache_obj": cache_obj} + + print(openai.ChatCompletion.fill_base_args(foo="hello")) + + response = await openai.ChatCompletion.acreate( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ], + ) + + openai.ChatCompletion.llm = None + openai.ChatCompletion.cache_args = {} + assert get_message_from_openai_answer(response) == expect_answer, response + + # def test_audio_api(): # data2vec = Data2VecAudio() # data_manager = manager_factory("sqlite,faiss,local", "audio_api", vector_params={"dimension": data2vec.dimension})
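
Usage note (illustrative, not part of the diff above): a minimal sketch of the async chat-completion entry point this patch adds. It assumes the patch is applied, the installed openai package is a pre-1.0 release exposing ChatCompletion.acreate, and OPENAI_API_KEY is set in the environment; the model name and prompt are placeholders.

import asyncio

from gptcache import cache
from gptcache.adapter import openai  # GPTCache wrapper, not the raw openai package
from gptcache.config import Config
from gptcache.utils.response import get_message_from_openai_answer


async def main():
    # Exact-match cache; enable_token_counter mirrors the saved-token branch in
    # ChatCompletion.acreate's cache_data_convert above.
    cache.init(config=Config(enable_token_counter=True))
    cache.set_openai_key()  # reads OPENAI_API_KEY from the environment

    messages = [{"role": "user", "content": "calculate 1+3"}]

    # First call reaches OpenAI through _allm_handler and writes the answer to the cache.
    first = await openai.ChatCompletion.acreate(model="gpt-3.5-turbo", messages=messages)
    # A second identical call should be answered from the cache by aadapt().
    second = await openai.ChatCompletion.acreate(model="gpt-3.5-turbo", messages=messages)

    print(get_message_from_openai_answer(first))
    print(get_message_from_openai_answer(second))


if __name__ == "__main__":
    asyncio.run(main())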
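
A similar sketch for the async streaming path (stream=True), mirroring test_stream_openai_async. On a cache miss the live AsyncGenerator is re-yielded through hook_openai_data so the cache still gets updated; on a hit the cached answer is replayed as chunks via async_iter, so async for works either way. Same assumptions as above.

import asyncio

from gptcache import cache
from gptcache.adapter import openai
from gptcache.utils.response import get_stream_message_from_openai_answer


async def main():
    cache.init()
    cache.set_openai_key()

    response = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "calculate 1+4"}],
        stream=True,
    )
    answer = ""
    async for chunk in response:  # works for both live and cached (async_iter) responses
        answer += get_stream_message_from_openai_answer(chunk)
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())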
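
The Completion wrapper gains the same async path. Below is a sketch following test_completion_async and test_completion_error_wrapping: get_prompt keys the cache on the prompt text, and failures raised inside the handler come back through wrap_error() as objects that are both openai.OpenAIError and gptcache CacheError. The model name is only an example.

import asyncio

import openai as real_openai

from gptcache import cache
from gptcache.adapter import openai
from gptcache.processor.pre import get_prompt
from gptcache.utils.error import CacheError
from gptcache.utils.response import get_text_from_openai_answer


async def main():
    cache.init(pre_embedding_func=get_prompt)
    cache.set_openai_key()
    try:
        response = await openai.Completion.acreate(
            model="text-davinci-003", prompt="what is your name?"
        )
        print(get_text_from_openai_answer(response))
    except real_openai.OpenAIError as err:
        # _allm_handler re-raises OpenAI failures via wrap_error(), so the exception is
        # still an OpenAIError and also a CacheError (see test_completion_error_wrapping).
        print("call failed:", isinstance(err, CacheError), err)


if __name__ == "__main__":
    asyncio.run(main())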
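
Finally, a sketch of the BaseCacheLLM hooks exercised by test_base_llm_cache_async: ChatCompletion.llm swaps in a custom upstream callable (here simply proxying to the real acreate), and ChatCompletion.cache_args pre-fills keyword arguments such as cache_obj through fill_base_args. The data_dir value is an arbitrary example, and the same OPENAI_API_KEY assumption as above applies.

import asyncio

import openai as real_openai

from gptcache import Cache
from gptcache.adapter import openai
from gptcache.adapter.api import init_similar_cache
from gptcache.processor.pre import last_content


async def main():
    cache_obj = Cache()
    # Builds a similarity-based cache (embedding + vector store) under the given directory.
    init_similar_cache(data_dir="acreate_demo_cache", pre_func=last_content, cache_obj=cache_obj)

    def proxy_chat_complete(*args, **kwargs):
        # _allm_handler awaits cls.llm(...), so returning the acreate coroutine is enough.
        return real_openai.ChatCompletion.acreate(*args, **kwargs)

    openai.ChatCompletion.llm = proxy_chat_complete
    openai.ChatCompletion.cache_args = {"cache_obj": cache_obj}
    try:
        response = await openai.ChatCompletion.acreate(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "What's Github"}],
        )
        print(response)
    finally:
        # Reset the class-level hooks so later callers get the default behaviour.
        openai.ChatCompletion.llm = None
        openai.ChatCompletion.cache_args = {}


if __name__ == "__main__":
    asyncio.run(main())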