Skip to content

Commit

Permalink
Add pipeline support for search (#2038)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvora-h authored Mar 8, 2022
1 parent 1f2259f commit 5bf9034
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 16 deletions.
30 changes: 25 additions & 5 deletions redis/commands/search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import redis

from .commands import SearchCommands


Expand All @@ -17,7 +19,7 @@ def __init__(self, client, chunk_size=1000):

self.client = client
self.execute_command = client.execute_command
self.pipeline = client.pipeline(transaction=False, shard_hint=None)
self._pipeline = client.pipeline(transaction=False, shard_hint=None)
self.total = 0
self.chunk_size = chunk_size
self.current_chunk = 0
Expand All @@ -42,7 +44,7 @@ def add_document(
"""
self.client._add_document(
doc_id,
conn=self.pipeline,
conn=self._pipeline,
nosave=nosave,
score=score,
payload=payload,
Expand All @@ -67,7 +69,7 @@ def add_document_hash(
"""
self.client._add_document_hash(
doc_id,
conn=self.pipeline,
conn=self._pipeline,
score=score,
replace=replace,
)
Expand All @@ -80,7 +82,7 @@ def commit(self):
"""
Manually commit and flush the batch indexing query
"""
self.pipeline.execute()
self._pipeline.execute()
self.current_chunk = 0

def __init__(self, client, index_name="idx"):
Expand All @@ -90,7 +92,25 @@ def __init__(self, client, index_name="idx"):
If conn is not None, we employ an already existing redis connection
"""
self.MODULE_CALLBACKS = {}
self.client = client
self.index_name = index_name
self.execute_command = client.execute_command
self.pipeline = client.pipeline
self._pipeline = client.pipeline

def pipeline(self, transaction=True, shard_hint=None):
"""Creates a pipeline for the SEARCH module, that can be used for executing
SEARCH commands, as well as classic core commands.
"""
p = Pipeline(
connection_pool=self.client.connection_pool,
response_callbacks=self.MODULE_CALLBACKS,
transaction=transaction,
shard_hint=shard_hint,
)
p.index_name = self.index_name
return p


class Pipeline(SearchCommands, redis.client.Pipeline):
"""Pipeline for the module."""
31 changes: 20 additions & 11 deletions redis/commands/search/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import time
from typing import Dict, Union

from redis.client import Pipeline

from ..helpers import parse_to_dict
from ._util import to_string
from .aggregation import AggregateRequest, AggregateResult, Cursor
Expand Down Expand Up @@ -186,8 +188,6 @@ def _add_document(
"""
Internal add_document used for both batch and single doc indexing
"""
if conn is None:
conn = self.client

if partial or no_create:
replace = True
Expand All @@ -208,7 +208,11 @@ def _add_document(
args += ["LANGUAGE", language]
args.append("FIELDS")
args += list(itertools.chain(*fields.items()))
return conn.execute_command(*args)

if conn is not None:
return conn.execute_command(*args)

return self.execute_command(*args)

def _add_document_hash(
self,
Expand All @@ -221,8 +225,6 @@ def _add_document_hash(
"""
Internal add_document_hash used for both batch and single doc indexing
"""
if conn is None:
conn = self.client

args = [ADDHASH_CMD, self.index_name, doc_id, score]

Expand All @@ -232,7 +234,10 @@ def _add_document_hash(
if language:
args += ["LANGUAGE", language]

return conn.execute_command(*args)
if conn is not None:
return conn.execute_command(*args)

return self.execute_command(*args)

def add_document(
self,
Expand Down Expand Up @@ -331,12 +336,13 @@ def delete_document(self, doc_id, conn=None, delete_actual_document=False):
For more information: https://oss.redis.com/redisearch/Commands/#ftdel
""" # noqa
args = [DEL_CMD, self.index_name, doc_id]
if conn is None:
conn = self.client
if delete_actual_document:
args.append("DD")

return conn.execute_command(*args)
if conn is not None:
return conn.execute_command(*args)

return self.execute_command(*args)

def load_document(self, id):
"""
Expand Down Expand Up @@ -364,7 +370,7 @@ def get(self, *ids):
For more information https://oss.redis.com/redisearch/Commands/#ftget
"""

return self.client.execute_command(MGET_CMD, self.index_name, *ids)
return self.execute_command(MGET_CMD, self.index_name, *ids)

def info(self):
"""
Expand All @@ -374,7 +380,7 @@ def info(self):
For more information https://oss.redis.com/redisearch/Commands/#ftinfo
"""

res = self.client.execute_command(INFO_CMD, self.index_name)
res = self.execute_command(INFO_CMD, self.index_name)
it = map(to_string, res)
return dict(zip(it, it))

Expand Down Expand Up @@ -423,6 +429,9 @@ def search(
st = time.time()
res = self.execute_command(SEARCH_CMD, *args)

if isinstance(res, Pipeline):
return res

return Result(
res,
not query._no_content,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,3 +1577,21 @@ def test_geo_params(modclient):
assert "doc1" == res.docs[0].id
assert "doc2" == res.docs[1].id
assert "doc3" == res.docs[2].id


@pytest.mark.redismod
def test_search_commands_in_pipeline(client):
p = client.ft().pipeline()
p.create_index((TextField("txt"),))
p.add_document("doc1", payload="foo baz", txt="foo bar")
p.add_document("doc2", txt="foo bar")
q = Query("foo bar").with_payloads()
p.search(q)
res = p.execute()
assert res[:3] == ["OK", "OK", "OK"]
assert 2 == res[3][0]
assert "doc1" == res[3][1]
assert "doc2" == res[3][4]
assert "foo baz" == res[3][2]
assert res[3][5] is None
assert res[3][3] == res[3][6] == ["txt", "foo bar"]

0 comments on commit 5bf9034

Please sign in to comment.