From dca16d51fdb4d6c399ceee2af3130e248039df72 Mon Sep 17 00:00:00 2001 From: filipecosta90 Date: Thu, 3 Oct 2024 14:35:21 +0100 Subject: [PATCH] Adjust code for re-indexing --- engine/clients/redis/config.py | 2 +- engine/clients/redis/configure.py | 5 +- engine/clients/redis/upload.py | 88 ++++++++++++++++++------------- 3 files changed, 53 insertions(+), 42 deletions(-) diff --git a/engine/clients/redis/config.py b/engine/clients/redis/config.py index e322511e..8fcf04c2 100644 --- a/engine/clients/redis/config.py +++ b/engine/clients/redis/config.py @@ -5,7 +5,7 @@ REDIS_USER = os.getenv("REDIS_USER", None) REDIS_CLUSTER = bool(int(os.getenv("REDIS_CLUSTER", 0))) REDIS_HYBRID_POLICY = os.getenv("REDIS_HYBRID_POLICY", None) -REDIS_KEEP_DOCUMENTS = os.getenv("REDIS_KEEP_DOCUMENTS", 0) +REDIS_KEEP_DOCUMENTS = bool(os.getenv("REDIS_KEEP_DOCUMENTS", 0)) GPU_STATS = bool(int(os.getenv("GPU_STATS", 0))) GPU_STATS_ENDPOINT = os.getenv("GPU_STATS_ENDPOINT", None) diff --git a/engine/clients/redis/configure.py b/engine/clients/redis/configure.py index 49992b25..edcb20ae 100644 --- a/engine/clients/redis/configure.py +++ b/engine/clients/redis/configure.py @@ -53,10 +53,7 @@ def clean(self): for conn in conns: index = conn.ft() try: - if REDIS_KEEP_DOCUMENTS: - index.dropindex(delete_documents=False) - else: - index.dropindex(delete_documents=True) + index.dropindex(delete_documents=(not REDIS_KEEP_DOCUMENTS)) except redis.ResponseError as e: str_err = e.__str__() if ( diff --git a/engine/clients/redis/upload.py b/engine/clients/redis/upload.py index 80fd12a0..39ecfa65 100644 --- a/engine/clients/redis/upload.py +++ b/engine/clients/redis/upload.py @@ -3,7 +3,7 @@ from ml_dtypes import bfloat16 import requests import json - +import random import numpy as np from redis import Redis, RedisCluster from engine.base_client.upload import BaseUploader @@ -14,6 +14,7 @@ REDIS_CLUSTER, GPU_STATS, GPU_STATS_ENDPOINT, + REDIS_KEEP_DOCUMENTS, ) from engine.clients.redis.helper import convert_to_redis_coords @@ -48,46 +49,50 @@ def init_client(cls, host, distance, connection_params, upload_params): cls.np_data_type = np.float16 if cls.data_type == "BFLOAT16": cls.np_data_type = bfloat16 + cls._is_cluster = True if REDIS_CLUSTER else False @classmethod def upload_batch( cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]] ): - p = cls.client.pipeline(transaction=False) - for i in range(len(ids)): - idx = ids[i] - vec = vectors[i] - meta = metadata[i] if metadata else {} - geopoints = {} - payload = {} - if meta is not None: - for k, v in meta.items(): - # This is a patch for arxiv-titles dataset where we have a list of "labels", and - # we want to index all of them under the same TAG field (whose separator is ';'). - if k == "labels": - payload[k] = ";".join(v) - if ( - v is not None - and not isinstance(v, dict) - and not isinstance(v, list) - ): - payload[k] = v - # Redis treats geopoints differently and requires putting them as - # a comma-separated string with lat and lon coordinates - geopoints = { - k: ",".join(map(str, convert_to_redis_coords(v["lon"], v["lat"]))) - for k, v in meta.items() - if isinstance(v, dict) - } - cls.client.hset( - str(idx), - mapping={ - "vector": np.array(vec).astype(cls.np_data_type).tobytes(), - **payload, - **geopoints, - }, - ) - p.execute() + # if we don't delete the docs we can skip sending them again + # By default we always send the docs + if REDIS_KEEP_DOCUMENTS is False: + p = cls.client.pipeline(transaction=False) + for i in range(len(ids)): + idx = ids[i] + vec = vectors[i] + meta = metadata[i] if metadata else {} + geopoints = {} + payload = {} + if meta is not None: + for k, v in meta.items(): + # This is a patch for arxiv-titles dataset where we have a list of "labels", and + # we want to index all of them under the same TAG field (whose separator is ';'). + if k == "labels": + payload[k] = ";".join(v) + if ( + v is not None + and not isinstance(v, dict) + and not isinstance(v, list) + ): + payload[k] = v + # Redis treats geopoints differently and requires putting them as + # a comma-separated string with lat and lon coordinates + geopoints = { + k: ",".join(map(str, convert_to_redis_coords(v["lon"], v["lat"]))) + for k, v in meta.items() + if isinstance(v, dict) + } + cls.client.hset( + str(idx), + mapping={ + "vector": np.array(vec).astype(cls.np_data_type).tobytes(), + **payload, + **geopoints, + }, + ) + p.execute() @classmethod def post_upload(cls, _distance): @@ -120,7 +125,16 @@ def post_upload(cls, _distance): return {} def get_memory_usage(cls): - used_memory = cls.client_decode.info("memory")["used_memory"] + used_memory = [] + conns = [cls.client_decode] + if cls._is_cluster: + conns = [ + cls.client_decode.get_redis_connection(node) + for node in cls.client_decode.get_primaries() + ] + for conn in conns: + used_memory_shard = conn.info("memory")["used_memory"] + used_memory.append(used_memory_shard) index_info = {} device_info = {} if cls.algorithm != "HNSW" and cls.algorithm != "FLAT":