Commit

new: add local inference batch size
joein committed Jan 29, 2025
1 parent 48907e6 commit 1ba8fa5
Showing 7 changed files with 176 additions and 47 deletions.
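This commit adds a local_inference_batch_size option to the client constructor: it sets the batch size used for local FastEmbed inference and is ignored when cloud_inference=True. A minimal usage sketch, assuming a reachable Qdrant instance, the fastembed extra installed, and an existing "demo" collection configured for the model's vectors; the URL, collection name, and texts are placeholders:

import asyncio

from qdrant_client import AsyncQdrantClient, models


async def main() -> None:
    # local_inference_batch_size is the option introduced by this commit.
    client = AsyncQdrantClient(url="http://localhost:6333", local_inference_batch_size=8)

    points = [
        models.PointStruct(
            id=i,
            vector=models.Document(text=text, model="BAAI/bge-small-en"),
        )
        for i, text in enumerate(["first document", "second document"])
    ]

    # Because the vectors are Document objects, the client embeds them locally
    # with FastEmbed, batching the texts 8 at a time before sending the upsert.
    await client.upsert(collection_name="demo", points=points)


asyncio.run(main())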
86 changes: 73 additions & 13 deletions qdrant_client/async_qdrant_client.py
@@ -96,6 +96,7 @@ def __init__(
             Union[Callable[[], str], Callable[[], Awaitable[str]]]
         ] = None,
         cloud_inference: bool = False,
+        local_inference_batch_size: Optional[int] = None,
         check_compatibility: bool = True,
         **kwargs: Any,
     ):
@@ -143,6 +144,7 @@ def __init__(
                 "Cloud inference is not supported for local Qdrant, consider using FastEmbed or switch to Qdrant Cloud"
             )
         self.cloud_inference = cloud_inference
+        self.local_inference_batch_size = local_inference_batch_size

     async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None:
         """Closes the connection to Qdrant
@@ -396,7 +398,11 @@ async def query_batch_points(
         requests = self._resolve_query_batch_request(requests)
         requires_inference = self._inference_inspector.inspect(requests)
         if requires_inference and (not self.cloud_inference):
-            requests = list(self._embed_models(requests, is_query=True))
+            requests = list(
+                self._embed_models(
+                    requests, is_query=True, batch_size=self.local_inference_batch_size
+                )
+            )
         return await self._client.query_batch_points(
             collection_name=collection_name,
             requests=requests,
@@ -524,13 +530,31 @@ async def query_points(
         requires_inference = self._inference_inspector.inspect([query, prefetch])
         if requires_inference and (not self.cloud_inference):
             query = (
-                next(iter(self._embed_models(query, is_query=True))) if query is not None else None
+                next(
+                    iter(
+                        self._embed_models(
+                            query, is_query=True, batch_size=self.local_inference_batch_size
+                        )
+                    )
+                )
+                if query is not None
+                else None
             )
             if isinstance(prefetch, list):
-                prefetch = list(self._embed_models(prefetch, is_query=True))
+                prefetch = list(
+                    self._embed_models(
+                        prefetch, is_query=True, batch_size=self.local_inference_batch_size
+                    )
+                )
             else:
                 prefetch = (
-                    next(iter(self._embed_models(prefetch, is_query=True)))
+                    next(
+                        iter(
+                            self._embed_models(
+                                prefetch, is_query=True, batch_size=self.local_inference_batch_size
+                            )
+                        )
+                    )
                     if prefetch is not None
                     else None
                 )
@@ -670,12 +694,30 @@ async def query_points_groups(
         requires_inference = self._inference_inspector.inspect([query, prefetch])
         if requires_inference and (not self.cloud_inference):
             query = (
-                next(iter(self._embed_models(query, is_query=True))) if query is not None else None
+                next(
+                    iter(
+                        self._embed_models(
+                            query, is_query=True, batch_size=self.local_inference_batch_size
+                        )
+                    )
+                )
+                if query is not None
+                else None
             )
             if isinstance(prefetch, list):
-                prefetch = list(self._embed_models(prefetch, is_query=True))
+                prefetch = list(
+                    self._embed_models(
+                        prefetch, is_query=True, batch_size=self.local_inference_batch_size
+                    )
+                )
             elif prefetch is not None:
-                prefetch = next(iter(self._embed_models(prefetch, is_query=True)))
+                prefetch = next(
+                    iter(
+                        self._embed_models(
+                            prefetch, is_query=True, batch_size=self.local_inference_batch_size
+                        )
+                    )
+                )
         return await self._client.query_points_groups(
             collection_name=collection_name,
             query=query,
@@ -1518,9 +1560,19 @@ async def upsert(
         requires_inference = self._inference_inspector.inspect(points)
         if requires_inference and (not self.cloud_inference):
             if isinstance(points, types.Batch):
-                points = next(iter(self._embed_models(points, is_query=False)))
+                points = next(
+                    iter(
+                        self._embed_models(
+                            points, is_query=False, batch_size=self.local_inference_batch_size
+                        )
+                    )
+                )
             else:
-                points = list(self._embed_models(points, is_query=False))
+                points = list(
+                    self._embed_models(
+                        points, is_query=False, batch_size=self.local_inference_batch_size
+                    )
+                )
         return await self._client.upsert(
             collection_name=collection_name,
             points=points,
@@ -1571,7 +1623,11 @@ async def update_vectors(
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
         requires_inference = self._inference_inspector.inspect(points)
         if requires_inference and (not self.cloud_inference):
-            points = list(self._embed_models(points, is_query=False))
+            points = list(
+                self._embed_models(
+                    points, is_query=False, batch_size=self.local_inference_batch_size
+                )
+            )
         return await self._client.update_vectors(
             collection_name=collection_name,
             points=points,
@@ -2011,7 +2067,11 @@ async def batch_update_points(
         assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
         requires_inference = self._inference_inspector.inspect(update_operations)
         if requires_inference and (not self.cloud_inference):
-            update_operations = list(self._embed_models(update_operations, is_query=False))
+            update_operations = list(
+                self._embed_models(
+                    update_operations, is_query=False, batch_size=self.local_inference_batch_size
+                )
+            )
         return await self._client.batch_update_points(
             collection_name=collection_name,
             update_operations=update_operations,
@@ -2452,7 +2512,7 @@ def chain(*iterables: Iterable) -> Iterable:
             points = []
         if requires_inference:
             points = self._embed_models_strict(
-                points, parallel=parallel, batch_size=batch_size
+                points, parallel=parallel, batch_size=self.local_inference_batch_size
             )
         return self._client.upload_points(
             collection_name=collection_name,
@@ -2523,7 +2583,7 @@ def chain(*iterables: Iterable) -> Iterable:
             vectors = []
         if requires_inference:
             vectors = self._embed_models_strict(
-                vectors, parallel=parallel, batch_size=batch_size
+                vectors, parallel=parallel, batch_size=self.local_inference_batch_size
             )
         return self._client.upload_collection(
             collection_name=collection_name,
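The same setting is threaded through the read path above: when query or prefetch contains raw inference objects, they are embedded locally with batch_size=self.local_inference_batch_size before the request is dispatched. Continuing the sketch from the top of the diff (same client and async context; collection name and model are placeholders):

# Sketch only: a Document query is embedded locally, batched by
# local_inference_batch_size, before query_points sends the request.
hits = await client.query_points(
    collection_name="demo",
    query=models.Document(text="what is a vector database?", model="BAAI/bge-small-en"),
    limit=5,
)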
15 changes: 10 additions & 5 deletions qdrant_client/async_qdrant_fastembed.py
@@ -42,12 +42,13 @@

 class AsyncQdrantFastembedMixin(AsyncQdrantBase):
     DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en"
+    DEFAULT_BATCH_SIZE = 16
     _FASTEMBED_INSTALLED: bool

     def __init__(self, parser: ModelSchemaParser, **kwargs: Any):
         self._embedding_model_name: Optional[str] = None
         self._sparse_embedding_model_name: Optional[str] = None
-        self._model_embedder = ModelEmbedder(parser=parser)
+        self._model_embedder = ModelEmbedder(parser=parser, **kwargs)
         try:
             from fastembed import SparseTextEmbedding, TextEmbedding

@@ -804,18 +805,22 @@ def _embed_models(
         self,
         raw_models: Union[BaseModel, Iterable[BaseModel]],
         is_query: bool = False,
-        batch_size: int = 32,
+        batch_size: Optional[int] = None,
     ) -> Iterable[BaseModel]:
         yield from self._model_embedder.embed_models(
-            raw_models=raw_models, is_query=is_query, batch_size=batch_size
+            raw_models=raw_models,
+            is_query=is_query,
+            batch_size=batch_size or self.DEFAULT_BATCH_SIZE,
         )

     def _embed_models_strict(
         self,
         raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
-        batch_size: int = 32,
+        batch_size: Optional[int] = None,
         parallel: Optional[int] = None,
     ) -> Iterable[BaseModel]:
         yield from self._model_embedder.embed_models_strict(
-            raw_models=raw_models, batch_size=batch_size, parallel=parallel
+            raw_models=raw_models,
+            batch_size=batch_size or self.DEFAULT_BATCH_SIZE,
+            parallel=parallel,
         )
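Because the client stores the option as Optional[int], the mixin resolves the effective value with batch_size or self.DEFAULT_BATCH_SIZE, so the per-call default drops from 32 to the new class constant of 16, and an explicit 0 also falls back to 16. A standalone illustration of that resolution rule (not library code):

from typing import Optional

DEFAULT_BATCH_SIZE = 16


def effective_batch_size(batch_size: Optional[int]) -> int:
    # `or` treats both None and 0 as "unset", so both fall back to 16.
    return batch_size or DEFAULT_BATCH_SIZE


assert effective_batch_size(None) == 16
assert effective_batch_size(8) == 8
assert effective_batch_size(0) == 16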
3 changes: 1 addition & 2 deletions qdrant_client/embed/embedder.py
@@ -64,7 +64,6 @@ def get_or_init_model(
             "device_ids": device_ids,
             **kwargs,
         }
-
         for instance in self.embedding_models[model_name]:
             if (deprecated and instance.deprecated) or (
                 not deprecated and instance.options == options
@@ -193,9 +192,9 @@ def embed(
     ) -> NumericVector:
         if (texts is None) is (images is None):
             raise ValueError("Either documents or images should be provided")
+
         if model_name in SUPPORTED_EMBEDDING_MODELS:
             embedding_model_inst = self.get_or_init_model(model_name=model_name, **options or {})
-
             if not is_query:
                 embeddings = [
                     embedding.tolist()
14 changes: 8 additions & 6 deletions qdrant_client/embed/model_embedder.py
@@ -37,7 +37,7 @@ def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]


 class ModelEmbedder:
-    MAX_INTERNAL_BATCH_SIZE = 4
+    MAX_INTERNAL_BATCH_SIZE = 64

     def __init__(self, parser: Optional[ModelSchemaParser] = None, **kwargs: Any):
         self._batch_accumulator: dict[str, list[INFERENCE_OBJECT_TYPES]] = {}
@@ -49,7 +49,7 @@ def embed_models(
         self,
         raw_models: Union[BaseModel, Iterable[BaseModel]],
         is_query: bool = False,
-        batch_size: int = 32,
+        batch_size: int = 16,
     ) -> Iterable[BaseModel]:
         """Embed raw data fields in models and return models with vectors
@@ -70,7 +70,7 @@ def embed_models_strict(
     def embed_models_strict(
         self,
         raw_models: Iterable[Union[dict[str, BaseModel], BaseModel]],
-        batch_size: int = 32,
+        batch_size: int = 16,
         parallel: Optional[int] = None,
     ) -> Iterable[Union[dict[str, BaseModel], BaseModel]]:
         """Embed raw data fields in models and return models with vectors
@@ -92,12 +92,14 @@ def embed_models_strict(
         if len(raw_models) < batch_size:
             is_small = True

-        raw_models_batches = iter_batch(raw_models, batch_size)
-
         if parallel is None or parallel == 1 or is_small:
-            for batch in raw_models_batches:
+            for batch in iter_batch(raw_models, batch_size):
                 yield from self.embed_models_batch(batch)
         else:
+            raw_models_batches = iter_batch(
+                raw_models, size=1
+            )  # larger batch sizes do not help with data parallel
+            # on cpu. todo: adjust when multi-gpu is available
             if parallel == 0:
                 parallel = os.cpu_count()

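The hunk above also changes the batching policy of embed_models_strict: the sequential path batches by the requested batch_size, while the parallel path now uses size-1 batches because larger batches do not speed up data parallelism on CPU (and MAX_INTERNAL_BATCH_SIZE grows from 4 to 64). A self-contained sketch of that policy, with a generic iter_batch helper and a stand-in embed_batch function, neither of which is the library's actual code:

from itertools import islice
from typing import Iterable, Iterator, List, Optional


def iter_batch(items: Iterable[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most `size` items."""
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch


def embed_batch(batch: List[str]) -> List[List[float]]:
    """Stand-in for one model inference call on a batch of texts."""
    return [[float(len(text))] for text in batch]


def embed_all(texts: List[str], batch_size: int, parallel: Optional[int]) -> List[List[float]]:
    vectors: List[List[float]] = []
    if parallel is None or parallel == 1 or len(texts) < batch_size:
        # Sequential path: respect the caller's batch size.
        for batch in iter_batch(texts, batch_size):
            vectors.extend(embed_batch(batch))
    else:
        # Parallel path: size-1 batches, mirroring the diff's comment that larger
        # batches do not help data parallelism on CPU; in the real implementation
        # each batch would be handed to a worker process.
        for batch in iter_batch(texts, 1):
            vectors.extend(embed_batch(batch))
    return vectors


print(embed_all(["a", "bb", "ccc", "dddd"], batch_size=2, parallel=None))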
(3 more changed files not shown)
