diff --git a/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py b/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
index f45553f9..616fdaae 100644
--- a/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
+++ b/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
@@ -76,7 +76,7 @@ def optimize_model(
         if Path(model_name_or_path).exists()
         else Path(HUGGINGFACE_HUB_CACHE) / "infinity_onnx" / model_name_or_path
     )
-    files_optimized = list(path_folder.glob("**/*optimized.onnx"))
+    files_optimized = list(path_folder.glob(f"**/*{execution_provider}_optimized.onnx"))
     if execution_provider == "TensorrtExecutionProvider":
         return model_class.from_pretrained(
             model_name_or_path,
@@ -142,7 +142,7 @@ def optimize_model(
             revision=revision,
             trust_remote_code=trust_remote_code,
             provider=execution_provider,
-            file_name=Path(file_name).name.replace(".onnx", "_optimized.onnx"),
+            file_name=Path(file_name).name.replace(".onnx", f"{execution_provider}_optimized.onnx"),
         )
     except Exception as e:
         logger.warning(
diff --git a/libs/simpleinference/poetry.lock b/libs/simpleinference/poetry.lock
index eb0be067..9e0362c1 100644
--- a/libs/simpleinference/poetry.lock
+++ b/libs/simpleinference/poetry.lock
@@ -936,17 +936,21 @@ name = "infinity-emb"
 version = "0.0.49"
 description = "Infinity is a high-throughput, low-latency REST API for serving vector embeddings, supporting a wide range of sentence-transformer models and frameworks."
 optional = false
-python-versions = ">=3.9,<4"
-files = []
-develop = false
+python-versions = "<4,>=3.9"
+files = [
+    {file = "infinity_emb-0.0.49-py3-none-any.whl", hash = "sha256:377a2c08f8ca4a4e992c7fba4680fd1aaadc4cce1db61d1015802f4d2c4cf543"},
+    {file = "infinity_emb-0.0.49.tar.gz", hash = "sha256:6f0fc9c61d8ca342a501db8cfe095ef0875eb87d34257c57876f047fff2a14f9"},
+]
 
 [package.dependencies]
 hf_transfer = ">=0.1.5"
 huggingface_hub = "*"
 numpy = ">=1.20.0,<2"
-optimum = {version = ">=1.16.2", extras = ["onnxruntime"], optional = true}
-pillow = {version = "*", optional = true}
-timm = {version = "*", optional = true}
+optimum = {version = ">=1.16.2", extras = ["onnxruntime"], optional = true, markers = "extra == \"optimum\" or extra == \"all\""}
+pillow = {version = "*", optional = true, markers = "extra == \"vision\" or extra == \"all\""}
+sentence-transformers = {version = ">=3.0.1,<4.0.0", optional = true, markers = "extra == \"ct2\" or extra == \"torch\" or extra == \"all\""}
+timm = {version = "*", optional = true, markers = "extra == \"vision\" or extra == \"all\""}
+torch = {version = ">=2.2.1", optional = true, markers = "extra == \"ct2\" or extra == \"torch\" or extra == \"all\""}
 
 [package.extras]
 all = ["ctranslate2 (>=4.0.0,<5.0.0)", "diskcache", "einops", "fastapi (>=0.103.2)", "optimum[onnxruntime] (>=1.16.2)", "orjson (>=3.9.8,!=3.10.0)", "pillow", "prometheus-fastapi-instrumentator (>=6.1.0)", "pydantic (>=2.4.0,<3)", "rich (>=13,<14)", "sentence-transformers (>=3.0.1,<4.0.0)", "timm", "torch (>=2.2.1)", "typer[all] (>=0.9.0,<0.10.0)", "uvicorn[standard] (>=0.23.2,<0.24.0)"]
@@ -961,10 +965,6 @@ tensorrt = ["tensorrt (>=8.6.1,<9.0.0)"]
 torch = ["sentence-transformers (>=3.0.1,<4.0.0)", "torch (>=2.2.1)"]
 vision = ["pillow", "timm"]
 
-[package.source]
-type = "directory"
-url = "../infinity_emb"
-
 [[package]]
 name = "iniconfig"
 version = "2.0.0"
@@ -1007,6 +1007,17 @@ MarkupSafe = ">=2.0"
 [package.extras]
 i18n = ["Babel (>=2.7)"]
 
+[[package]]
+name = "joblib"
+version = "1.4.2"
+description = "Lightweight pipelining with Python functions"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
+    {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
+]
+
 [[package]]
 name = "markupsafe"
 version = "2.1.5"
@@ -2318,6 +2329,118 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
 testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
 torch = ["safetensors[numpy]", "torch (>=1.10)"]
 
+[[package]]
+name = "scikit-learn"
+version = "1.5.0"
+description = "A set of python modules for machine learning and data mining"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "scikit_learn-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12e40ac48555e6b551f0a0a5743cc94cc5a765c9513fe708e01f0aa001da2801"},
+    {file = "scikit_learn-1.5.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f405c4dae288f5f6553b10c4ac9ea7754d5180ec11e296464adb5d6ac68b6ef5"},
+    {file = "scikit_learn-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df8ccabbf583315f13160a4bb06037bde99ea7d8211a69787a6b7c5d4ebb6fc3"},
+    {file = "scikit_learn-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c75ea812cd83b1385bbfa94ae971f0d80adb338a9523f6bbcb5e0b0381151d4"},
+    {file = "scikit_learn-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:a90c5da84829a0b9b4bf00daf62754b2be741e66b5946911f5bdfaa869fcedd6"},
+    {file = "scikit_learn-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2a65af2d8a6cce4e163a7951a4cfbfa7fceb2d5c013a4b593686c7f16445cf9d"},
+    {file = "scikit_learn-1.5.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:4c0c56c3005f2ec1db3787aeaabefa96256580678cec783986836fc64f8ff622"},
+    {file = "scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f77547165c00625551e5c250cefa3f03f2fc92c5e18668abd90bfc4be2e0bff"},
+    {file = "scikit_learn-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:118a8d229a41158c9f90093e46b3737120a165181a1b58c03461447aa4657415"},
+    {file = "scikit_learn-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:a03b09f9f7f09ffe8c5efffe2e9de1196c696d811be6798ad5eddf323c6f4d40"},
+    {file = "scikit_learn-1.5.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:460806030c666addee1f074788b3978329a5bfdc9b7d63e7aad3f6d45c67a210"},
+    {file = "scikit_learn-1.5.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:1b94d6440603752b27842eda97f6395f570941857456c606eb1d638efdb38184"},
+    {file = "scikit_learn-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d82c2e573f0f2f2f0be897e7a31fcf4e73869247738ab8c3ce7245549af58ab8"},
+    {file = "scikit_learn-1.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3a10e1d9e834e84d05e468ec501a356226338778769317ee0b84043c0d8fb06"},
+    {file = "scikit_learn-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:855fc5fa8ed9e4f08291203af3d3e5fbdc4737bd617a371559aaa2088166046e"},
+    {file = "scikit_learn-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:40fb7d4a9a2db07e6e0cae4dc7bdbb8fada17043bac24104d8165e10e4cff1a2"},
+    {file = "scikit_learn-1.5.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:47132440050b1c5beb95f8ba0b2402bbd9057ce96ec0ba86f2f445dd4f34df67"},
+    {file = "scikit_learn-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:174beb56e3e881c90424e21f576fa69c4ffcf5174632a79ab4461c4c960315ac"},
+    {file = "scikit_learn-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:261fe334ca48f09ed64b8fae13f9b46cc43ac5f580c4a605cbb0a517456c8f71"},
+    {file = "scikit_learn-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:057b991ac64b3e75c9c04b5f9395eaf19a6179244c089afdebaad98264bff37c"},
+    {file = "scikit_learn-1.5.0.tar.gz", hash = "sha256:789e3db01c750ed6d496fa2db7d50637857b451e57bcae863bff707c1247bef7"},
+]
+
+[package.dependencies]
+joblib = ">=1.2.0"
+numpy = ">=1.19.5"
+scipy = ">=1.6.0"
+threadpoolctl = ">=3.1.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
+build = ["cython (>=3.0.10)", "meson-python (>=0.15.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.15.0)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"]
+examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
+install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
+maintenance = ["conda-lock (==2.5.6)"]
+tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.23)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"]
+
+[[package]]
+name = "scipy"
+version = "1.13.1"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca"},
+    {file = "scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f"},
+    {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989"},
+    {file = "scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f"},
+    {file = "scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94"},
+    {file = "scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54"},
+    {file = "scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9"},
+    {file = "scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326"},
+    {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299"},
+    {file = "scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa"},
+    {file = "scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59"},
+    {file = "scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b"},
+    {file = "scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1"},
+    {file = "scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d"},
+    {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627"},
+    {file = "scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884"},
+    {file = "scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16"},
+    {file = "scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949"},
+    {file = "scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5"},
+    {file = "scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24"},
+    {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004"},
+    {file = "scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d"},
+    {file = "scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c"},
+    {file = "scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2"},
+    {file = "scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c"},
+]
+
+[package.dependencies]
+numpy = ">=1.22.4,<2.3"
+
+[package.extras]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"]
+test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
+[[package]]
+name = "sentence-transformers"
+version = "3.0.1"
+description = "Multilingual text embeddings"
+optional = false
+python-versions = ">=3.8.0"
+files = [
+    {file = "sentence_transformers-3.0.1-py3-none-any.whl", hash = "sha256:01050cc4053c49b9f5b78f6980b5a72db3fd3a0abb9169b1792ac83875505ee6"},
+    {file = "sentence_transformers-3.0.1.tar.gz", hash = "sha256:8a3d2c537cc4d1014ccc20ac92be3d6135420a3bc60ae29a3a8a9b4bb35fbff6"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.15.1"
+numpy = "*"
+Pillow = "*"
+scikit-learn = "*"
+scipy = "*"
+torch = ">=1.11.0"
+tqdm = "*"
+transformers = ">=4.34.0,<5.0.0"
+
+[package.extras]
+dev = ["accelerate (>=0.20.3)", "datasets", "pre-commit", "pytest", "ruff (>=0.3.0)"]
+train = ["accelerate (>=0.20.3)", "datasets"]
+
 [[package]]
 name = "sentencepiece"
 version = "0.2.0"
@@ -2418,6 +2541,17 @@ files = [
     {file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"},
 ]
 
+[[package]]
+name = "threadpoolctl"
+version = "3.5.0"
+description = "threadpoolctl"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
+    {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
+]
+
 [[package]]
 name = "timm"
 version = "1.0.7"
@@ -3127,4 +3261,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4"
-content-hash = "acece33c26f04700b579c67617fb7ae874a4be7455921d0e1dcbd1012a8c8a62"
+content-hash = "e1ab77837e695c15cc92f2f7674ebbf8441fa13a252e1a9cdaf8f79a773c7864"
diff --git a/libs/simpleinference/pyproject.toml b/libs/simpleinference/pyproject.toml
index dcc869f3..8c82f240 100644
--- a/libs/simpleinference/pyproject.toml
+++ b/libs/simpleinference/pyproject.toml
@@ -13,7 +13,7 @@ priority = "explicit"
 [tool.poetry.dependencies]
 python = ">=3.9,<4"
 # infinity_emb = {path = "../infinity_emb", extras = ["optimum","vision"]}
-infinity_emb = {version = "0.0.49", extras = ["optimum","vision"]}
+infinity_emb = {version = "0.0.49", extras = ["optimum","vision","torch"]}
 [tool.poetry.group.test.dependencies]
 pytest = "^7.0.0"
 coverage = {extras = ["toml"], version = "^7.3.2"}
diff --git a/libs/simpleinference/simpleinference/infer.py b/libs/simpleinference/simpleinference/infer.py
index 051d58ed..99f3e12f 100644
--- a/libs/simpleinference/simpleinference/infer.py
+++ b/libs/simpleinference/simpleinference/infer.py
@@ -1,7 +1,8 @@
 from concurrent.futures import Future
-from typing import Iterable, Literal, Union
+from typing import Collection, Literal, Union
 
 from infinity_emb import EngineArgs, SyncEngineArray  # type: ignore
+from infinity_emb.infinity_server import AutoPadding
 
 __all__ = ["SimpleInference"]
 
@@ -16,18 +17,17 @@ class SimpleInference:
     def __init__(
         self,
         *,
-        model_id: Union[ModelID, Iterable[ModelID]],
-        engine: Union[Engine, Iterable[Engine]] = "optimum",
-        device: Union[Device, Iterable[Device]] = "cpu",
-        embedding_dtype: Union[EmbeddingDtype, Iterable[EmbeddingDtype]] = "float32",
+        model_id: Union[ModelID, Collection[ModelID]],
+        engine: Union[Engine, Collection[Engine]] = "optimum",
+        device: Union[Device, Collection[Device]] = "cpu",
+        embedding_dtype: Union[EmbeddingDtype, Collection[EmbeddingDtype]] = "float32",
     ):
         """An easy interface to infer with multiple models.
-        >>> ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5")
+        >>> ei = SimpleInference(model_id=['michaelfeil/bge-small-en-v1.5','mixedbread-ai/mxbai-rerank-xsmall-v1'])
         >>> ei
-        SimpleInference(['michaelfeil/bge-small-en-v1.5'])
-        >>> ei.stop()
+        SimpleInference(['michaelfeil/bge-small-en-v1.5', 'mixedbread-ai/mxbai-rerank-xsmall-v1'])
+        >>> ei.stop() # always stop when you are done
         """
-
         if isinstance(model_id, str):
             model_id = [model_id]
         if isinstance(engine, str):
@@ -36,21 +36,20 @@ def __init__(
             device = [device]
         if isinstance(embedding_dtype, str):
             embedding_dtype = [embedding_dtype]
-        self._engine_args = [
-            EngineArgs(
-                model_name_or_path=m,
-                engine=e,  # type: ignore
-                device=d,  # type: ignore
-                served_model_name=m,
-                embedding_dtype=edt,  # type: ignore
-                lengths_via_tokenize=True,
-                model_warmup=False,
-            )
-            for m, e, d, edt in zip(model_id, engine, device, embedding_dtype)
-        ]
-        self._engine_array = SyncEngineArray.from_args(
-            engine_args_array=self._engine_args
+        pad = AutoPadding(
+            length=len(model_id),
+            # pass through arguments
+            model_name_or_path=model_id,
+            engine=engine,
+            device=device,
+            embedding_dtype=embedding_dtype,
+            # opinionated defaults
+            lengths_via_tokenize=True,
+            model_warmup=True,
+            trust_remote_code=True,
         )
+        self._engine_args = [EngineArgs(**kwargs) for kwargs in pad]
+        self._engine_array = SyncEngineArray.from_args(self._engine_args)
 
     def stop(self):
         self._engine_array.stop()
@@ -66,14 +66,14 @@ def embed(
     ) -> Future[tuple[list[list[float]], int]]:
         """Embed sentences with a model.
 
-        >>> ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5")
+        >>> ei = SimpleInference(model_id="michaelfeil/bge-small-en-v1.5", engine="torch")
         >>> embed_result = ei.embed(model_id="michaelfeil/bge-small-en-v1.5", sentences=["Hello, world!"])
        >>> type(embed_result)
 
         >>> embed_result.result()[0][0].shape # embedding
         (384,)
         >>> embed_result.result()[1] # embedding and usage of 6 tokens
-        6
+        4
         >>> ei.stop()
         """
         return self._engine_array.embed(model=model_id, sentences=sentences)
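Reviewer note, not part of the patch: a minimal usage sketch of the SimpleInference API after this change, assembled from the doctests above. It assumes the package is installed with the new "torch" extra and that the two Hugging Face models can be downloaded; the import path, model names, and printed values are illustrative.

# Usage sketch only; mirrors the doctests in simpleinference/infer.py above.
from simpleinference.infer import SimpleInference  # assumed import path

# Load two models at once. Scalar arguments such as engine/device are
# broadcast to every model via AutoPadding inside __init__.
ei = SimpleInference(
    model_id=[
        "michaelfeil/bge-small-en-v1.5",
        "mixedbread-ai/mxbai-rerank-xsmall-v1",
    ],
    engine="torch",  # needs the "torch" extra added in pyproject.toml
    device="cpu",
)

# embed() is non-blocking and returns a concurrent.futures.Future;
# .result() waits for the (embeddings, token_usage) tuple.
embed_result = ei.embed(
    model_id="michaelfeil/bge-small-en-v1.5",
    sentences=["Hello, world!"],
)
embeddings, usage = embed_result.result()
print(embeddings[0].shape, usage)  # e.g. (384,) 4

ei.stop()  # always stop when you are done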