Skip to content

Commit

Permalink
(feat) init redis cache with **kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff committed Dec 5, 2023
1 parent 31f3187 commit 9ba1765
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
10 changes: 6 additions & 4 deletions litellm/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def flush_cache(self):


class RedisCache(BaseCache):
def __init__(self, host, port, password):
def __init__(self, host, port, password, **kwargs):
import redis
# if users don't provider one, use the default litellm cache
self.redis_client = redis.Redis(host=host, port=port, password=password)
self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs)

def set_cache(self, key, value, **kwargs):
ttl = kwargs.get("ttl", None)
Expand Down Expand Up @@ -168,7 +168,8 @@ def __init__(
type="local",
host=None,
port=None,
password=None
password=None,
**kwargs
):
"""
Initializes the cache based on the given type.
Expand All @@ -178,6 +179,7 @@ def __init__(
host (str, optional): The host address for the Redis cache. Required if type is "redis".
port (int, optional): The port number for the Redis cache. Required if type is "redis".
password (str, optional): The password for the Redis cache. Required if type is "redis".
**kwargs: Additional keyword arguments for redis.Redis() cache
Raises:
ValueError: If an invalid cache type is provided.
Expand All @@ -186,7 +188,7 @@ def __init__(
None
"""
if type == "redis":
self.cache = RedisCache(host, port, password)
self.cache = RedisCache(host, port, password, **kwargs)
if type == "local":
self.cache = InMemoryCache()
if "cache" not in litellm.input_callback:
Expand Down
27 changes: 25 additions & 2 deletions litellm/tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_embedding_caching():
print(f"embedding2: {embedding2}")
pytest.fail("Error occurred: Embedding caching failed")

test_embedding_caching()
# test_embedding_caching()


def test_embedding_caching_azure():
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_redis_cache_completion():
print(f"response4: {response4}")
pytest.fail(f"Error occurred:")

test_redis_cache_completion()
# test_redis_cache_completion()

# redis cache with custom keys
def custom_get_cache_key(*args, **kwargs):
Expand Down Expand Up @@ -231,6 +231,29 @@ def get_cache(key):

# test_custom_redis_cache_with_key()


def test_custom_redis_cache_params():
# test if we can init redis with **kwargs
try:
litellm.cache = Cache(
type="redis",
host=os.environ['REDIS_HOST'],
port=os.environ['REDIS_PORT'],
password=os.environ['REDIS_PASSWORD'],
db = 0,
ssl=True,
ssl_certfile="./redis_user.crt",
ssl_keyfile="./redis_user_private.key",
ssl_ca_certs="./redis_ca.pem",
)

print(litellm.cache.cache.redis_client)
litellm.cache = None
except Exception as e:
pytest.fail(f"Error occurred:", e)

# test_custom_redis_cache_params()

# def test_redis_cache_with_ttl():
# cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
# sample_model_response_object_str = """{
Expand Down

0 comments on commit 9ba1765

Please sign in to comment.