diff --git a/fastembed/embedding.py b/fastembed/embedding.py index 52543b55..f642c65e 100644 --- a/fastembed/embedding.py +++ b/fastembed/embedding.py @@ -2,6 +2,8 @@ import os import shutil import tarfile +import tempfile + from abc import ABC, abstractmethod from itertools import islice from multiprocessing import get_all_start_methods @@ -459,7 +461,9 @@ def __init__( Args: model_name (str): The name of the model to use. max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512. - cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory. + cache_dir (str, optional): The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. Raises: @@ -468,7 +472,8 @@ def __init__( self.model_name = model_name if cache_dir is None: - cache_dir = Path(".").resolve() / "local_cache" + default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache") + cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir)) cache_dir.mkdir(parents=True, exist_ok=True) self._cache_dir = cache_dir @@ -576,7 +581,9 @@ def __init__( Args: model_name (str): The name of the model to use. max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512. - cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory. + cache_dir (str, optional): The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None. Raises: ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
@@ -584,7 +591,8 @@ def __init__( self.model_name = model_name if cache_dir is None: - cache_dir = Path(".").resolve() / "local_cache" + default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache") + cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir)) cache_dir.mkdir(parents=True, exist_ok=True) self._cache_dir = cache_dir