diff --git a/gptcache/adapter/langchain_models.py b/gptcache/adapter/langchain_models.py
index 4113c543..4a3ce9cb 100644
--- a/gptcache/adapter/langchain_models.py
+++ b/gptcache/adapter/langchain_models.py
@@ -1,10 +1,10 @@
 from typing import Optional, List, Any, Mapping

 from gptcache.adapter.adapter import adapt, aadapt
+from gptcache.core import cache
 from gptcache.manager.scalar_data.base import Answer, DataType
 from gptcache.session import Session
 from gptcache.utils import import_pydantic, import_langchain
-from gptcache.core import Cache,cache

 import_pydantic()
 import_langchain()
@@ -51,7 +51,6 @@ class LangChainLLMs(LLM, BaseModel):

     llm: Any
     session: Session = None
-    cache_obj: Cache = cache
     tmp_args: Any = None

     @property
@@ -76,13 +75,14 @@ def _call(
             if "session" not in self.tmp_args
             else self.tmp_args.pop("session")
         )
+        cache_obj = self.tmp_args.pop("cache_obj", cache)
         return adapt(
             self.llm,
             _cache_data_convert,
             _update_cache_callback,
             prompt=prompt,
             stop=stop,
-            cache_obj=self.cache_obj,
+            cache_obj=cache_obj,
             session=session,
             **self.tmp_args,
         )
@@ -153,9 +153,8 @@ def _llm_type(self) -> str:
         return "gptcache_llm_chat"

     chat: Any
-    session: Session = None
-    cache_obj: Cache = cache
-    tmp_args: Any = None
+    session: Optional[Session] = None
+    tmp_args: Optional[Any] = None

     def _generate(
         self,
@@ -168,13 +167,14 @@ def _generate(
             if "session" not in self.tmp_args
             else self.tmp_args.pop("session")
         )
+        cache_obj = self.tmp_args.pop("cache_obj", cache)
         return adapt(
             self.chat._generate,
             _cache_msg_data_convert,
             _update_cache_msg_callback,
             messages=messages,
             stop=stop,
-            cache_obj=self.cache_obj,
+            cache_obj=cache_obj,
             session=session,
             run_manager=run_manager,
             **self.tmp_args,
         )
@@ -191,14 +191,14 @@ async def _agenerate(
             if "session" not in self.tmp_args
             else self.tmp_args.pop("session")
         )
-
+        cache_obj = self.tmp_args.pop("cache_obj", cache)
         return await aadapt(
             self.chat._agenerate,
             _cache_msg_data_convert,
             _update_cache_msg_callback,
             messages=messages,
             stop=stop,
-            cache_obj=self.cache_obj,
+            cache_obj=cache_obj,
             session=session,
             run_manager=run_manager,
             **self.tmp_args,
         )
diff --git a/gptcache/adapter/openai.py b/gptcache/adapter/openai.py
index ddc3ca89..9a898ea3 100644
--- a/gptcache/adapter/openai.py
+++ b/gptcache/adapter/openai.py
@@ -117,6 +117,8 @@ def cache_data_convert(cache_data):
                 saved_token = [input_token, output_token]
             else:
                 saved_token = [0, 0]
+            if kwargs.get("stream", False):
+                return _construct_stream_resp_from_cache(cache_data, saved_token)
             return _construct_resp_from_cache(cache_data, saved_token)

         kwargs = cls.fill_base_args(**kwargs)
diff --git a/tests/requirements.txt b/tests/requirements.txt
index f3c6b42c..fdab2575 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -16,6 +16,7 @@ pytest-sugar==0.9.5
 pytest-parallel
 psycopg2-binary
 transformers==4.29.2
+anyio==3.6.2
 torch
 mock
 pexpect
diff --git a/tests/unit_tests/test_client.py b/tests/unit_tests/test_client.py
index 03a6f7a7..c0c8a576 100644
--- a/tests/unit_tests/test_client.py
+++ b/tests/unit_tests/test_client.py
@@ -1,5 +1,8 @@
 from unittest.mock import patch, Mock

+from gptcache.utils import import_httpx
+
+import_httpx()
 from gptcache.client import Client
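
Note on the langchain_models.py hunks: `cache_obj` is no longer a pydantic field bound to the global `cache` singleton at class-definition time; it is popped from `tmp_args` on every call, so each request can target its own `Cache`. A minimal usage sketch, not part of the patch (the prompt and `temperature` value are placeholders):

```python
from gptcache import Cache
from gptcache.processor.pre import get_prompt
from gptcache.adapter.langchain_models import LangChainLLMs
from langchain.llms import OpenAI

# A dedicated cache for this wrapper instead of the process-wide singleton.
llm_cache = Cache()
llm_cache.init(pre_embedding_func=get_prompt)

llm = LangChainLLMs(llm=OpenAI(temperature=0))
# cache_obj now travels through tmp_args and is popped per call;
# omitting it falls back to the global `cache` default.
answer = llm("What is GPTCache?", cache_obj=llm_cache)
```

This also avoids the import-order pitfall of the old field default, which captured whatever `cache` happened to be when the class was created.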
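Note on the openai.py hunk: it fixes cache hits under `stream=True`, which previously returned a single non-streamed response object regardless of the flag. A sketch of the expected behavior after the patch, assuming a cache has already been initialized and warmed (the model, prompt, and chunk shape follow the OpenAI streaming convention):

```python
from gptcache.adapter import openai

# On a cache hit with stream=True, the adapter now returns OpenAI-shaped
# chunked deltas built by _construct_stream_resp_from_cache instead of a
# single ChatCompletion-style response.
for chunk in openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is GPTCache?"}],
    stream=True,
):
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content", ""), end="")
```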