Support caching of async completion and cache completion (#513)
* Use the old version of chromadb (#492)

Signed-off-by: SimFG <[email protected]>
Signed-off-by: Reuben Thomas-Davis <[email protected]>

* added support for weaviate vector database (#493)

* added support for weaviate vector database

Signed-off-by: pranaychandekar <[email protected]>

* added support for local db for weaviate vector store

Signed-off-by: pranaychandekar <[email protected]>

* added unit test case for weaviate vector store

Signed-off-by: pranaychandekar <[email protected]>

* resolved unit test case error for weaviate vector store

Signed-off-by: pranaychandekar <[email protected]>

* increased code coverage; resolved pylint issues

pylint: disabled C0413

Signed-off-by: pranaychandekar <[email protected]>

---------

Signed-off-by: pranaychandekar <[email protected]>
Signed-off-by: Reuben Thomas-Davis <[email protected]>

* Update the version to `0.1.37` (#494)

Signed-off-by: SimFG <[email protected]>
Signed-off-by: Reuben Thomas-Davis <[email protected]>

* ✨ support caching of async completion and cache completion

Signed-off-by: Reuben Thomas-Davis <[email protected]>

* ✨ add streaming support for chatcompletion

Signed-off-by: Reuben Thomas-Davis <[email protected]>

* ✅ improve test coverage and formatting

Signed-off-by: Reuben Thomas-Davis <[email protected]>

* correct merge duplication

Signed-off-by: Reuben Thomas-Davis <[email protected]>

* correct update cache callback

Signed-off-by: Reuben Thomas-Davis <[email protected]>

* add additional tests for improved coverage

Signed-off-by: Reuben Thomas-Davis <[email protected]>

* remove redundant param in docstring

Signed-off-by: Reuben Thomas-Davis <[email protected]>

---------

Signed-off-by: SimFG <[email protected]>
Signed-off-by: Reuben Thomas-Davis <[email protected]>
Signed-off-by: pranaychandekar <[email protected]>
Co-authored-by: SimFG <[email protected]>
Co-authored-by: Pranay Chandekar <[email protected]>
3 people authored Sep 5, 2023
1 parent bca8de9 commit d19bf74
Showing 3 changed files with 403 additions and 34 deletions.
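
Before the diffs, a minimal usage sketch of what this commit enables: caching an async chat completion end to end through the new ChatCompletion.acreate entry point. This sketch is not part of the commit; the model name, the question, and the default cache.init() / set_openai_key() setup are illustrative assumptions.

import asyncio

from gptcache import cache
from gptcache.adapter import openai

cache.init()            # default exact-match cache setup
cache.set_openai_key()  # reads OPENAI_API_KEY from the environment

async def main():
    question = "what is github?"
    # First call: misses the cache, goes to OpenAI via _allm_handler,
    # and the update callback writes the answer back to the cache.
    first = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": question}],
    )
    # Second identical call: served from the cache by cache_data_convert.
    second = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": question}],
    )
    print(first["choices"][0]["message"]["content"])
    print(second["choices"][0]["message"]["content"])

asyncio.run(main())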
1 change: 0 additions & 1 deletion gptcache/adapter/adapter.py
@@ -513,7 +513,6 @@ def update_cache_func(handled_llm_data, question=None):
                 == 0
             ):
                 chat_cache.flush()
-
         llm_data = update_cache_callback(
             llm_data, update_cache_func, *args, **kwargs
         )
109 changes: 97 additions & 12 deletions gptcache/adapter/openai.py
@@ -3,21 +3,21 @@
 import os
 import time
 from io import BytesIO
-from typing import Iterator, Any, List
+from typing import Any, AsyncGenerator, Iterator, List

 from gptcache import cache
-from gptcache.adapter.adapter import adapt
+from gptcache.adapter.adapter import aadapt, adapt
 from gptcache.adapter.base import BaseCacheLLM
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.utils import import_openai, import_pillow
 from gptcache.utils.error import wrap_error
 from gptcache.utils.response import (
-    get_stream_message_from_openai_answer,
-    get_message_from_openai_answer,
-    get_text_from_openai_answer,
-    get_audio_text_from_openai_answer,
     get_image_from_openai_b64,
     get_image_from_openai_url,
+    get_audio_text_from_openai_answer,
+    get_message_from_openai_answer,
+    get_stream_message_from_openai_answer,
+    get_text_from_openai_answer,
 )
 from gptcache.utils.token import token_counter
@@ -56,15 +56,40 @@ class ChatCompletion(openai.ChatCompletion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if cls.llm is None else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if cls.llm is None
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e

+    @classmethod
+    async def _allm_handler(cls, *llm_args, **llm_kwargs):
+        try:
+            return (
+                (await super().acreate(*llm_args, **llm_kwargs))
+                if cls.llm is None
+                else await cls.llm(*llm_args, **llm_kwargs)
+            )
+        except openai.OpenAIError as e:
+            raise wrap_error(e) from e
+
     @staticmethod
     def _update_cache_callback(
         llm_data, update_cache_func, *args, **kwargs
     ):  # pylint: disable=unused-argument
-        if not isinstance(llm_data, Iterator):
+        if isinstance(llm_data, AsyncGenerator):
+
+            async def hook_openai_data(it):
+                total_answer = ""
+                async for item in it:
+                    total_answer += get_stream_message_from_openai_answer(item)
+                    yield item
+                update_cache_func(Answer(total_answer, DataType.STR))
+
+            return hook_openai_data(llm_data)
+        elif not isinstance(llm_data, Iterator):
             update_cache_func(
                 Answer(get_message_from_openai_answer(llm_data), DataType.STR)
             )
@@ -92,8 +117,6 @@ def cache_data_convert(cache_data):
                 saved_token = [input_token, output_token]
             else:
                 saved_token = [0, 0]
-            if kwargs.get("stream", False):
-                return _construct_stream_resp_from_cache(cache_data, saved_token)
             return _construct_resp_from_cache(cache_data, saved_token)

         kwargs = cls.fill_base_args(**kwargs)
@@ -105,6 +128,38 @@ def cache_data_convert(cache_data):
             **kwargs,
         )

+    @classmethod
+    async def acreate(cls, *args, **kwargs):
+        chat_cache = kwargs.get("cache_obj", cache)
+        enable_token_counter = chat_cache.config.enable_token_counter
+
+        def cache_data_convert(cache_data):
+            if enable_token_counter:
+                input_token = _num_tokens_from_messages(kwargs.get("messages"))
+                output_token = token_counter(cache_data)
+                saved_token = [input_token, output_token]
+            else:
+                saved_token = [0, 0]
+            if kwargs.get("stream", False):
+                return async_iter(
+                    _construct_stream_resp_from_cache(cache_data, saved_token)
+                )
+            return _construct_resp_from_cache(cache_data, saved_token)
+
+        kwargs = cls.fill_base_args(**kwargs)
+        return await aadapt(
+            cls._allm_handler,
+            cache_data_convert,
+            cls._update_cache_callback,
+            *args,
+            **kwargs,
+        )
+
+
+async def async_iter(input_list):
+    for item in input_list:
+        yield item

 class Completion(openai.Completion, BaseCacheLLM):
     """Openai Completion Wrapper
@@ -128,7 +183,22 @@ class Completion(openai.Completion, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if not cls.llm
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e

+    @classmethod
+    async def _allm_handler(cls, *llm_args, **llm_kwargs):
+        try:
+            return (
+                (await super().acreate(*llm_args, **llm_kwargs))
+                if cls.llm is None
+                else await cls.llm(*llm_args, **llm_kwargs)
+            )
+        except openai.OpenAIError as e:
+            raise wrap_error(e) from e

Expand All @@ -154,6 +224,17 @@ def create(cls, *args, **kwargs):
**kwargs,
)

@classmethod
async def acreate(cls, *args, **kwargs):
kwargs = cls.fill_base_args(**kwargs)
return await aadapt(
cls._allm_handler,
cls._cache_data_convert,
cls._update_cache_callback,
*args,
**kwargs,
)


class Audio(openai.Audio):
"""Openai Audio Wrapper
@@ -319,7 +400,11 @@ class Moderation(openai.Moderation, BaseCacheLLM):
     @classmethod
     def _llm_handler(cls, *llm_args, **llm_kwargs):
         try:
-            return super().create(*llm_args, **llm_kwargs) if not cls.llm else cls.llm(*llm_args, **llm_kwargs)
+            return (
+                super().create(*llm_args, **llm_kwargs)
+                if not cls.llm
+                else cls.llm(*llm_args, **llm_kwargs)
+            )
         except openai.OpenAIError as e:
             raise wrap_error(e) from e
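
A companion sketch for the streaming path (same assumed setup, again not part of the commit). On a miss, _update_cache_callback wraps the provider stream in hook_openai_data, so the cache is written only after the stream is fully drained; on a hit, acreate's cache_data_convert returns async_iter(...) over chunks rebuilt by _construct_stream_resp_from_cache. Either way the caller sees an async generator and consumes it the same way.

import asyncio

from gptcache import cache
from gptcache.adapter import openai

cache.init()
cache.set_openai_key()

async def stream_answer(question: str) -> str:
    resp = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": question}],
        stream=True,
    )
    parts = []
    async for chunk in resp:  # async generator on both cache hit and miss
        delta = chunk["choices"][0]["delta"]
        parts.append(delta.get("content", ""))  # role/finish chunks carry no content
    return "".join(parts)

print(asyncio.run(stream_answer("what is github?")))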

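Finally, a hedged sketch of the plain-text Completion path: Completion.acreate reuses the class's existing _cache_data_convert and _update_cache_callback, so only the handler goes async. The model and prompt are placeholders.

import asyncio

from gptcache import cache
from gptcache.adapter import openai

cache.init()
cache.set_openai_key()

async def main():
    resp = await openai.Completion.acreate(
        model="text-davinci-003",
        prompt="what is github?",
    )
    print(resp["choices"][0]["text"])

asyncio.run(main())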