Local inference upload collection and upload records (#862)
* wip: add draft implementation of batch processing

* fix: embed dict and list of docs, remove redundant code

* new: regen async, small refactor

* refactor: add docstrings, rename methods

* Upload points local inference (#881)

* new: separate single and plural model embeddings

* fix: fix lazy embed models

* new: add inference object inspections to upload methods

* wip: local inference upload parallel

* new: add local inference to upload points and upload collection, refactor mixin

* fix: remove redundant code

* redundant import

* tests: check is_query for query points batch

* refactor: refactor semi ordered map

* tests: add test for local inference with batches with docs and vectors

* tests: check the order of dict processing

* new: distinguish models by options

* fix: fix typing

* fix: fix types

* new: embed batches with different options

* tests: add tests for batch with different options

* fix: ignore incorrect IDE type inspection

* tests: wait for points to be inserted

* fix: set threads to 1 in parallel inference

* new: adjust max internal batch size

* fix: fix type hints

* function to get embeddings size (#892)

* function to get embeddings size

* async client

* keep sync

* new: extend embedding size to support image and late interaction models

---------

Co-authored-by: George Panchuk <[email protected]>

* new: add local inference batch size (#894)

---------

Co-authored-by: Andrey Vasnetsov <[email protected]>
joein and generall authored Jan 29, 2025
1 parent 188f3c6 commit b4ae703
Showing 16 changed files with 2,270 additions and 986 deletions.
134 changes: 120 additions & 14 deletions qdrant_client/async_qdrant_client.py
@@ -12,6 +12,7 @@
import warnings
from copy import deepcopy
from typing import Any, Awaitable, Callable, Iterable, Mapping, Optional, Sequence, Union
import numpy as np
from qdrant_client import grpc as grpc
from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.common.client_warnings import show_warning_once
@@ -95,6 +96,7 @@ def __init__(
Union[Callable[[], str], Callable[[], Awaitable[str]]]
] = None,
cloud_inference: bool = False,
local_inference_batch_size: Optional[int] = None,
check_compatibility: bool = True,
**kwargs: Any,
):
@@ -142,6 +144,7 @@ def __init__(
"Cloud inference is not supported for local Qdrant, consider using FastEmbed or switch to Qdrant Cloud"
)
self.cloud_inference = cloud_inference
self.local_inference_batch_size = local_inference_batch_size

async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None:
"""Closes the connection to Qdrant
@@ -395,7 +398,11 @@ async def query_batch_points(
requests = self._resolve_query_batch_request(requests)
requires_inference = self._inference_inspector.inspect(requests)
if requires_inference and (not self.cloud_inference):
requests = [self._embed_models(request) for request in requests]
requests = list(
self._embed_models(
requests, is_query=True, batch_size=self.local_inference_batch_size
)
)
return await self._client.query_batch_points(
collection_name=collection_name,
requests=requests,
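
A hedged sketch of the batched-inference path above: each request carries a Document, and with local inference they are all embedded together in batches of `local_inference_batch_size` (collection name, model, and query texts are assumptions; `client` is the one from the earlier sketch, used inside an async context):

requests = [
    models.QueryRequest(query=models.Document(text="first query", model="BAAI/bge-small-en")),
    models.QueryRequest(query=models.Document(text="second query", model="BAAI/bge-small-en")),
]
# documents from all requests are embedded locally before the batch is sent to the server
responses = await client.query_batch_points(collection_name="demo", requests=requests)
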
@@ -522,10 +529,35 @@ async def query_points(
query = self._resolve_query(query)
requires_inference = self._inference_inspector.inspect([query, prefetch])
if requires_inference and (not self.cloud_inference):
query = self._embed_models(query, is_query=True) if query is not None else None
prefetch = (
self._embed_models(prefetch, is_query=True) if prefetch is not None else None
query = (
next(
iter(
self._embed_models(
query, is_query=True, batch_size=self.local_inference_batch_size
)
)
)
if query is not None
else None
)
if isinstance(prefetch, list):
prefetch = list(
self._embed_models(
prefetch, is_query=True, batch_size=self.local_inference_batch_size
)
)
else:
prefetch = (
next(
iter(
self._embed_models(
prefetch, is_query=True, batch_size=self.local_inference_batch_size
)
)
)
if prefetch is not None
else None
)
return await self._client.query_points(
collection_name=collection_name,
query=query,
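
For reference, a hedged sketch of `query_points` with a Document query and a Prefetch that also requires local inference; both are embedded with the configured batch size (names and texts are illustrative):

# inside an async context, reusing `client` from the first sketch
response = await client.query_points(
    collection_name="demo",
    query=models.Document(text="what is a vector database?", model="BAAI/bge-small-en"),
    prefetch=models.Prefetch(
        query=models.Document(text="vector search engines", model="BAAI/bge-small-en"),
        limit=20,
    ),
    limit=5,
)
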
@@ -661,10 +693,31 @@ async def query_points_groups(
query = self._resolve_query(query)
requires_inference = self._inference_inspector.inspect([query, prefetch])
if requires_inference and (not self.cloud_inference):
query = self._embed_models(query, is_query=True) if query is not None else None
prefetch = (
self._embed_models(prefetch, is_query=True) if prefetch is not None else None
query = (
next(
iter(
self._embed_models(
query, is_query=True, batch_size=self.local_inference_batch_size
)
)
)
if query is not None
else None
)
if isinstance(prefetch, list):
prefetch = list(
self._embed_models(
prefetch, is_query=True, batch_size=self.local_inference_batch_size
)
)
elif prefetch is not None:
prefetch = next(
iter(
self._embed_models(
prefetch, is_query=True, batch_size=self.local_inference_batch_size
)
)
)
return await self._client.query_points_groups(
collection_name=collection_name,
query=query,
@@ -1506,10 +1559,20 @@ async def upsert(
)
requires_inference = self._inference_inspector.inspect(points)
if requires_inference and (not self.cloud_inference):
if isinstance(points, list):
points = [self._embed_models(point, is_query=False) for point in points]
if isinstance(points, types.Batch):
points = next(
iter(
self._embed_models(
points, is_query=False, batch_size=self.local_inference_batch_size
)
)
)
else:
points = self._embed_models(points, is_query=False)
points = list(
self._embed_models(
points, is_query=False, batch_size=self.local_inference_batch_size
)
)
return await self._client.upsert(
collection_name=collection_name,
points=points,
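
The hunk above distinguishes `types.Batch` from a plain list of points; a hedged sketch of the Batch variant with Documents as vectors (ids, texts, payloads, and the collection are assumptions):

# inside an async context, reusing `client` from the first sketch
await client.upsert(
    collection_name="demo",
    points=models.Batch(
        ids=[1, 2],
        vectors=[
            models.Document(text="first document", model="BAAI/bge-small-en"),
            models.Document(text="second document", model="BAAI/bge-small-en"),
        ],
        payloads=[{"source": "a"}, {"source": "b"}],
    ),
)
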
@@ -1560,7 +1623,11 @@ async def update_vectors(
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
requires_inference = self._inference_inspector.inspect(points)
if requires_inference and (not self.cloud_inference):
points = [self._embed_models(point, is_query=False) for point in points]
points = list(
self._embed_models(
points, is_query=False, batch_size=self.local_inference_batch_size
)
)
return await self._client.update_vectors(
collection_name=collection_name,
points=points,
@@ -2000,9 +2067,11 @@ async def batch_update_points(
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
requires_inference = self._inference_inspector.inspect(update_operations)
if requires_inference and (not self.cloud_inference):
update_operations = [
self._embed_models(op, is_query=False) for op in update_operations
]
update_operations = list(
self._embed_models(
update_operations, is_query=False, batch_size=self.local_inference_batch_size
)
)
return await self._client.batch_update_points(
collection_name=collection_name,
update_operations=update_operations,
@@ -2426,7 +2495,25 @@ def upload_points(
This parameter overwrites shard keys written in the records.
"""

def chain(*iterables: Iterable) -> Iterable:
for iterable in iterables:
yield from iterable

assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
if not self.cloud_inference:
iter_points = iter(points)
requires_inference = False
try:
point = next(iter_points)
requires_inference = self._inference_inspector.inspect(point)
points = chain(iter([point]), iter_points)
except (StopIteration, StopAsyncIteration):
points = []
if requires_inference:
points = self._embed_models_strict(
points, parallel=parallel, batch_size=self.local_inference_batch_size
)
return self._client.upload_points(
collection_name=collection_name,
points=points,
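
The local `chain` generator above lets the code peek at the first point to decide whether local inference is needed without consuming the lazy iterable. A standalone sketch of the same peek-and-restitch pattern, using `itertools.chain` instead of the local helper (names are illustrative):

from itertools import chain
from typing import Callable, Iterable, Tuple


def peek_requires_inference(
    items: Iterable, inspect: Callable[[object], bool]
) -> Tuple[bool, Iterable]:
    # pull the first element, inspect it, then stitch it back in front of the rest
    iterator = iter(items)
    try:
        first = next(iterator)
    except StopIteration:
        return False, []
    return inspect(first), chain([first], iterator)
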
@@ -2478,7 +2565,26 @@ def upload_collection(
If multiple shard_keys are provided, the update will be written to each of them.
Only works for collections with `custom` sharding method.
"""

def chain(*iterables: Iterable) -> Iterable:
for iterable in iterables:
yield from iterable

assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
if not self.cloud_inference:
if not isinstance(vectors, dict) and (not isinstance(vectors, np.ndarray)):
requires_inference = False
try:
iter_vectors = iter(vectors)
vector = next(iter_vectors)
requires_inference = self._inference_inspector.inspect(vector)
vectors = chain(iter([vector]), iter_vectors)
except (StopIteration, StopAsyncIteration):
vectors = []
if requires_inference:
vectors = self._embed_models_strict(
vectors, parallel=parallel, batch_size=self.local_inference_batch_size
)
return self._client.upload_collection(
collection_name=collection_name,
vectors=vectors,
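
Finally, a hedged usage sketch of `upload_collection` with Documents as vectors; dicts and NumPy arrays are passed through unchanged, while other iterables are inspected and, if needed, embedded locally in batches (collection name, texts, and payloads are assumptions):

# upload_collection stays synchronous in the async client, so no await here
client.upload_collection(
    collection_name="demo",
    vectors=[
        models.Document(text="first document", model="BAAI/bge-small-en"),
        models.Document(text="second document", model="BAAI/bge-small-en"),
    ],
    payload=[{"source": "a"}, {"source": "b"}],
    parallel=1,
)
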