Skip to content

Commit

Permalink
Adjust code for re-indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
filipecosta90 committed Oct 3, 2024
1 parent aa701aa commit dca16d5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 42 deletions.
2 changes: 1 addition & 1 deletion engine/clients/redis/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
REDIS_USER = os.getenv("REDIS_USER", None)
REDIS_CLUSTER = bool(int(os.getenv("REDIS_CLUSTER", 0)))
REDIS_HYBRID_POLICY = os.getenv("REDIS_HYBRID_POLICY", None)
# Parse through int() first, like REDIS_CLUSTER/GPU_STATS above: a plain
# bool() on the env string would turn "0" into True, since every non-empty
# string is truthy in Python.
REDIS_KEEP_DOCUMENTS = bool(int(os.getenv("REDIS_KEEP_DOCUMENTS", 0)))
GPU_STATS = bool(int(os.getenv("GPU_STATS", 0)))
GPU_STATS_ENDPOINT = os.getenv("GPU_STATS_ENDPOINT", None)

Expand Down
5 changes: 1 addition & 4 deletions engine/clients/redis/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ def clean(self):
for conn in conns:
index = conn.ft()
try:
if REDIS_KEEP_DOCUMENTS:
index.dropindex(delete_documents=False)
else:
index.dropindex(delete_documents=True)
index.dropindex(delete_documents=(not REDIS_KEEP_DOCUMENTS))
except redis.ResponseError as e:
str_err = e.__str__()
if (
Expand Down
88 changes: 51 additions & 37 deletions engine/clients/redis/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ml_dtypes import bfloat16
import requests
import json

import random
import numpy as np
from redis import Redis, RedisCluster
from engine.base_client.upload import BaseUploader
Expand All @@ -14,6 +14,7 @@
REDIS_CLUSTER,
GPU_STATS,
GPU_STATS_ENDPOINT,
REDIS_KEEP_DOCUMENTS,
)
from engine.clients.redis.helper import convert_to_redis_coords

Expand Down Expand Up @@ -48,46 +49,50 @@ def init_client(cls, host, distance, connection_params, upload_params):
cls.np_data_type = np.float16
if cls.data_type == "BFLOAT16":
cls.np_data_type = bfloat16
cls._is_cluster = True if REDIS_CLUSTER else False

@classmethod
def upload_batch(
    cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
):
    """Upload one batch of vectors (plus optional per-vector metadata) to
    Redis, one hash per document keyed by the stringified id.

    When REDIS_KEEP_DOCUMENTS is truthy the batch is skipped entirely:
    the documents from a previous run were kept when the index was
    dropped, so only the index needs rebuilding — not the data.
    """
    if REDIS_KEEP_DOCUMENTS:
        # Re-indexing an existing dataset; nothing to upload.
        return
    pipe = cls.client.pipeline(transaction=False)
    for i, (doc_id, vec) in enumerate(zip(ids, vectors)):
        meta = metadata[i] if metadata else {}
        payload = {}
        geopoints = {}
        if meta is not None:
            for key, value in meta.items():
                # Patch for the arxiv-titles dataset: a list of "labels" is
                # indexed under a single TAG field (whose separator is ';').
                if key == "labels":
                    payload[key] = ";".join(value)
                if (
                    value is not None
                    and not isinstance(value, dict)
                    and not isinstance(value, list)
                ):
                    payload[key] = value
            # Redis treats geopoints differently and requires putting them
            # as a comma-separated "lon,lat" string.
            geopoints = {
                key: ",".join(
                    map(str, convert_to_redis_coords(value["lon"], value["lat"]))
                )
                for key, value in meta.items()
                if isinstance(value, dict)
            }
        # Queue the HSET on the pipeline. (Previously each HSET was sent
        # directly via cls.client, leaving the pipeline permanently empty
        # and defeating the batching that pipe.execute() implies.)
        pipe.hset(
            str(doc_id),
            mapping={
                "vector": np.array(vec).astype(cls.np_data_type).tobytes(),
                **payload,
                **geopoints,
            },
        )
    pipe.execute()

@classmethod
def post_upload(cls, _distance):
Expand Down Expand Up @@ -120,7 +125,16 @@ def post_upload(cls, _distance):
return {}

def get_memory_usage(cls):
used_memory = cls.client_decode.info("memory")["used_memory"]
used_memory = []
conns = [cls.client_decode]
if cls._is_cluster:
conns = [
cls.client_decode.get_redis_connection(node)
for node in cls.client_decode.get_primaries()
]
for conn in conns:
used_memory_shard = conn.info("memory")["used_memory"]
used_memory.append(used_memory_shard)
index_info = {}
device_info = {}
if cls.algorithm != "HNSW" and cls.algorithm != "FLAT":
Expand Down

0 comments on commit dca16d5

Please sign in to comment.