fix: infer

michaelfeil committed Jun 23, 2024
1 parent 38012d0 commit 2c1868e
Showing 4 changed files with 173 additions and 39 deletions.
4 changes: 2 additions & 2 deletions libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
@@ -76,7 +76,7 @@ def optimize_model(
         if Path(model_name_or_path).exists()
         else Path(HUGGINGFACE_HUB_CACHE) / "infinity_onnx" / model_name_or_path
     )
-    files_optimized = list(path_folder.glob("**/*optimized.onnx"))
+    files_optimized = list(path_folder.glob(f"**/*{execution_provider}_optimized.onnx"))
     if execution_provider == "TensorrtExecutionProvider":
         return model_class.from_pretrained(
             model_name_or_path,
@@ -142,7 +142,7 @@ def optimize_model(
             revision=revision,
             trust_remote_code=trust_remote_code,
             provider=execution_provider,
-            file_name=Path(file_name).name.replace(".onnx", "_optimized.onnx"),
+            file_name=Path(file_name).name.replace(".onnx", f"{execution_provider}_optimized.onnx"),
         )
     except Exception as e:
         logger.warning(
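Note: the two hunks above keep the filename written by the optimizer and the cache-lookup glob in sync, so an ONNX graph optimized for one execution provider is no longer reused for another. A minimal sketch of that pairing (the helper names here are illustrative, not part of the library):

    from pathlib import Path

    def optimized_file_name(file_name: str, execution_provider: str) -> str:
        # "model.onnx" -> "modelCPUExecutionProvider_optimized.onnx"
        return Path(file_name).name.replace(".onnx", f"{execution_provider}_optimized.onnx")

    def find_cached_optimized(path_folder: Path, execution_provider: str) -> list[Path]:
        # Only matches artifacts produced for this provider, so CPU- and
        # CUDA-optimized graphs no longer shadow each other in the cache.
        return list(path_folder.glob(f"**/*{execution_provider}_optimized.onnx"))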
156 changes: 145 additions & 11 deletions libs/simpleinference/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion libs/simpleinference/pyproject.toml
@@ -13,7 +13,7 @@ priority = "explicit"
 [tool.poetry.dependencies]
 python = ">=3.9,<4"
 # infinity_emb = {path = "../infinity_emb", extras = ["optimum","vision"]}
-infinity_emb = {version = "0.0.49", extras = ["optimum","vision"]}
+infinity_emb = {version = "0.0.49", extras = ["optimum","vision","torch"]}
 [tool.poetry.group.test.dependencies]
 pytest = "^7.0.0"
 coverage = {extras = ["toml"], version = "^7.3.2"}
50 changes: 25 additions & 25 deletions libs/simpleinference/simpleinference/infer.py
@@ -1,7 +1,8 @@
 from concurrent.futures import Future
-from typing import Iterable, Literal, Union
+from typing import Collection, Literal, Union

 from infinity_emb import EngineArgs, SyncEngineArray  # type: ignore
+from infinity_emb.infinity_server import AutoPadding

 __all__ = ["SimpleInference"]

@@ -16,18 +17,17 @@ class SimpleInference:
     def __init__(
         self,
         *,
-        model_id: Union[ModelID, Iterable[ModelID]],
-        engine: Union[Engine, Iterable[Engine]] = "optimum",
-        device: Union[Device, Iterable[Device]] = "cpu",
-        embedding_dtype: Union[EmbeddingDtype, Iterable[EmbeddingDtype]] = "float32",
+        model_id: Union[ModelID, Collection[ModelID]],
+        engine: Union[Engine, Collection[Engine]] = "optimum",
+        device: Union[Device, Collection[Device]] = "cpu",
+        embedding_dtype: Union[EmbeddingDtype, Collection[EmbeddingDtype]] = "float32",
     ):
         """An easy interface to infer with multiple models.
-        >>> ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5")
+        >>> ei = SimpleInference(model_id=['michaelfeil/bge-small-en-v1.5','mixedbread-ai/mxbai-rerank-xsmall-v1'])
         >>> ei
-        SimpleInference(['michaelfeil/bge-small-en-v1.5'])
-        >>> ei.stop()
+        SimpleInference(['michaelfeil/bge-small-en-v1.5', 'mixedbread-ai/mxbai-rerank-xsmall-v1'])
+        >>> ei.stop()  # always stop when you are done
         """

         if isinstance(model_id, str):
             model_id = [model_id]
         if isinstance(engine, str):
@@ -36,21 +36,21 @@ def __init__(
             device = [device]
         if isinstance(embedding_dtype, str):
             embedding_dtype = [embedding_dtype]
-        self._engine_args = [
-            EngineArgs(
-                model_name_or_path=m,
-                engine=e,  # type: ignore
-                device=d,  # type: ignore
-                served_model_name=m,
-                embedding_dtype=edt,  # type: ignore
-                lengths_via_tokenize=True,
-                model_warmup=False,
-            )
-            for m, e, d, edt in zip(model_id, engine, device, embedding_dtype)
-        ]
-        self._engine_array = SyncEngineArray.from_args(
-            engine_args_array=self._engine_args
-        )
+        pad = AutoPadding(
+            length=len(model_id),
+            # pass through arguments
+            model_name_or_path=model_id,
+            engine=engine,
+            device=device,
+            embedding_dtype=embedding_dtype,
+            # opinionated defaults
+            lengths_via_tokenize=True,
+            model_warmup=True,
+            trust_remote_code=True,
+        )
+        self._engine_args = [EngineArgs(**kwargs) for kwargs in pad]
+        self._engine_array = SyncEngineArray.from_args(self._engine_args)

     def stop(self):
         self._engine_array.stop()
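Note: AutoPadding replaces the manual zip over per-model arguments. Judging from its use above, it broadcasts scalar arguments to `length` entries and yields one kwargs dict per model; the following is a hedged sketch of that behavior, not the infinity_emb implementation:

    from typing import Any, Iterator

    class AutoPaddingSketch:
        """Broadcast scalar kwargs to `length` entries; yield one kwargs dict per model."""

        def __init__(self, length: int, **kwargs: Any) -> None:
            self.length = length
            self.kwargs = kwargs

        def _pad(self, value: Any) -> list[Any]:
            # Lists/tuples pass through (length must match); anything else,
            # including strings, is repeated `length` times.
            if isinstance(value, (list, tuple)):
                if len(value) != self.length:
                    raise ValueError(f"expected {self.length} values, got {len(value)}")
                return list(value)
            return [value] * self.length

        def __iter__(self) -> Iterator[dict[str, Any]]:
            padded = {key: self._pad(value) for key, value in self.kwargs.items()}
            for i in range(self.length):
                yield {key: values[i] for key, values in padded.items()}

    # One scalar engine is repeated for both models:
    for kwargs in AutoPaddingSketch(
        length=2, model_name_or_path=["model-a", "model-b"], engine="optimum"
    ):
        print(kwargs)
    # {'model_name_or_path': 'model-a', 'engine': 'optimum'}
    # {'model_name_or_path': 'model-b', 'engine': 'optimum'}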
@@ -66,14 +66,14 @@ def embed(
     ) -> Future[tuple[list[list[float]], int]]:
         """Embed sentences with a model.
-        >>> ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5")
+        >>> ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5", engine="torch")
         >>> embed_result = ei.embed(model_id="michaelfeil/bge-small-en-v1.5", sentences=["Hello, world!"])
         >>> type(embed_result)
         <class 'concurrent.futures._base.Future'>
         >>> embed_result.result()[0][0].shape # embedding
         (384,)
         >>> embed_result.result()[1] # embedding and usage of 6 tokens
-        6
+        4
         >>> ei.stop()
         """
         return self._engine_array.embed(model=model_id, sentences=sentences)
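Note: a usage sketch of the Future-based API exercised by the doctest above (the import path is taken from this module; downloading the model weights is assumed to work):

    from simpleinference.infer import SimpleInference

    ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5", engine="torch")
    future = ei.embed(model_id="michaelfeil/bge-small-en-v1.5", sentences=["Hello, world!"])
    embeddings, usage = future.result()  # blocks until the engine has embedded the batch
    print(embeddings[0].shape, usage)    # e.g. (384,) and the token usage count
    ei.stop()  # always stop when you are done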
