Bugfix/#486: provide redis_connection for creating all Object Models (#487)

* [mod] provide redis_connection while model creation for consistency

Signed-off-by: Anurag Wagh <[email protected]>

* [add] documentation for `get_models` method

Signed-off-by: Anurag Wagh <[email protected]>

* [add] set redis connection details to common variable for unit test

Signed-off-by: Anurag Wagh <[email protected]>

* [add] provide redis connection for embedded model

Signed-off-by: Anurag Wagh <[email protected]>

* [add] add doc string for Counter class

Signed-off-by: Anurag Wagh <[email protected]>

---------

Signed-off-by: Anurag Wagh <[email protected]>
a9raag authored Jul 14, 2023
1 parent e811c62 commit a5fc129
Showing 2 changed files with 44 additions and 16 deletions.
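
The gist of the fix: every redis-om Object Model that `get_models` creates now pins its connection through the `Meta.database` attribute, instead of silently picking up the default connection pool. A minimal sketch of that pattern, assuming a local Redis at redis://localhost:6379 (the model below is illustrative, not taken from this diff):

from redis_om import Field, JsonModel, get_redis_connection

# An explicit connection rather than redis-om's implicit default pool.
con = get_redis_connection(url="redis://localhost:6379")

class Example(JsonModel):
    text: str = Field(index=True)

    class Meta:
        # Without this line, redis-om falls back to its default
        # connection, which is the inconsistency #486 describes.
        database = con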
gptcache/manager/scalar_data/redis_storage.py (32 additions, 9 deletions)

@@ -20,17 +20,32 @@
 from redis_om import JsonModel, EmbeddedJsonModel, NotFoundError, Field, Migrator
 
 
-def get_models(global_key):
+def get_models(global_key: str, redis_connection: Redis):
+    """
+    Get all the models for the given global key and redis connection.
+
+    :param global_key: Global key will be used as a prefix for all the keys
+    :type global_key: str
+    :param redis_connection: Redis connection to use for all the models.
+        Note: This needs to be explicitly mentioned in `Meta` class for each Object Model,
+        otherwise it will use the default connection from the pool.
+    :type redis_connection: Redis
+    """
+
     class Counter:
+        """
+        counter collection
+        """
         key_name = global_key + ":counter"
+        database = redis_connection
 
         @classmethod
-        def incr(cls, con: Redis):
-            con.incr(cls.key_name)
+        def incr(cls):
+            cls.database.incr(cls.key_name)
 
         @classmethod
-        def get(cls, con: Redis):
-            return con.get(cls.key_name)
+        def get(cls):
+            return cls.database.get(cls.key_name)
 
     class Embedding:
         """
@@ -75,6 +90,9 @@ class Answers(EmbeddedJsonModel):
         answer: str
         answer_type: int
 
+        class Meta:
+            database = redis_connection
+
     class Questions(JsonModel):
         """
         questions collection
@@ -89,6 +107,7 @@ class Questions(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "questions"
+            database = redis_connection
 
     class Sessions(JsonModel):
         """
@@ -98,6 +117,7 @@ class Sessions(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "sessions"
+            database = redis_connection
 
         session_id: str = Field(index=True)
         session_question: str
@@ -111,6 +131,7 @@ class QuestionDeps(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "ques_deps"
+            database = redis_connection
 
         question_id: str = Field(index=True)
         dep_name: str
@@ -125,6 +146,7 @@ class Report(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "report"
+            database = redis_connection
 
         user_question: str
         cache_question_id: int = Field(index=True)
@@ -194,16 +216,16 @@ def __init__(
             self._session,
             self._counter,
             self._report,
-        ) = get_models(global_key_prefix)
+        ) = get_models(global_key_prefix, redis_connection=self.con)
 
         Migrator().run()
 
     def create(self):
         pass
 
     def _insert(self, data: CacheData, pipeline: Pipeline = None):
-        self._counter.incr(self.con)
-        pk = str(self._counter.get(self.con))
+        self._counter.incr()
+        pk = str(self._counter.get())
         answers = data.answers if isinstance(data.answers, list) else [data.answers]
         all_data = []
         for answer in answers:
@@ -360,7 +382,8 @@ def delete_session(self, keys: List[str]):
             self._session.delete_many(sessions_to_delete, pipeline)
         pipeline.execute()
 
-    def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value, cache_delta_time):
+    def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value,
+                     cache_delta_time):
         self._report(
             user_question=user_question,
             cache_question=cache_question,
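
After this change, callers never have to touch `get_models` directly: they hand connection details to `RedisCacheStorage`, which forwards its own connection. A hedged usage sketch (the URL and prefix are illustrative; the keyword arguments match the test file below):

from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage

storage = RedisCacheStorage(
    global_key_prefix="gptcache",
    url="redis://default:default@localhost:6379",
)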
tests/unit_tests/manager/test_redis_cache_storage.py (12 additions, 7 deletions)

@@ -4,28 +4,30 @@
 import numpy as np
 
 from gptcache.manager.scalar_data.base import CacheData, Question
-from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage
+from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage, get_models
 from gptcache.utils import import_redis
 
 import_redis()
-from redis_om import get_redis_connection
+from redis_om import get_redis_connection, RedisModel
 
 
 class TestRedisStorage(unittest.TestCase):
     test_dbname = "gptcache_test"
+    url = "redis://default:default@localhost:6379"
 
     def setUp(cls) -> None:
         cls._clear_test_db()
 
     @staticmethod
     def _clear_test_db():
-        r = get_redis_connection()
+        r = get_redis_connection(url=TestRedisStorage.url)
         r.flushall()
         r.flushdb()
         time.sleep(1)
 
     def test_normal(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         data = []
         for i in range(1, 10):
             data.append(
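
The shared `url` class attribute keeps the connection the tests flush and the one the storage writes through pointed at the same server. The same explicit connection can also be passed straight to `get_models`, whose new signature appears earlier in this diff (the prefix is illustrative):

from redis_om import get_redis_connection

from gptcache.manager.scalar_data.redis_storage import get_models

con = get_redis_connection(url="redis://default:default@localhost:6379")
models = get_models("gptcache_test", redis_connection=con)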
@@ -61,7 +63,8 @@ def test_normal(self):
         assert redis_storage.count(is_all=True) == 7
 
     def test_with_deps(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         data_id = redis_storage.batch_insert(
             [
                 CacheData(
@@ -98,7 +101,8 @@ def test_with_deps(self):
         assert ret.question.deps[1].dep_type == 1
 
     def test_create_on(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         redis_storage.create()
         data = []
         for i in range(1, 10):
@@ -124,7 +128,8 @@
         assert last_access1 < last_access2
 
     def test_session(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         data = []
         for i in range(1, 11):
             data.append(
