Skip to content

Commit

Permalink
feat: Added support for FASTEMBED_CACHE_PATH env var (#68)
Browse files (browse the repository at this point in the history)
* chore: FASTEMBED_CACHE_PATH env

* chore: temp directory fallback

* chore: tempdir fallback JinaEmbedding
  • Loading branch information
Anush008 authored Nov 22, 2023
1 parent f222d7c commit 0a94425
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions fastembed/embedding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import shutil
import tarfile
import tempfile

from abc import ABC, abstractmethod
from itertools import islice
from multiprocessing import get_all_start_methods
Expand Down Expand Up @@ -459,7 +461,9 @@ def __init__(
Args:
model_name (str): The name of the model to use.
max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
Raises:
Expand All @@ -468,7 +472,8 @@ def __init__(
self.model_name = model_name

if cache_dir is None:
cache_dir = Path(".").resolve() / "local_cache"
default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache")
cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
cache_dir.mkdir(parents=True, exist_ok=True)

self._cache_dir = cache_dir
Expand Down Expand Up @@ -576,15 +581,18 @@ def __init__(
Args:
model_name (str): The name of the model to use.
max_length (int, optional): The maximum number of tokens. Defaults to 512. Unknown behavior for values > 512.
cache_dir (str, optional): The path to the cache directory. Defaults to `local_cache` in the current directory.
cache_dir (str, optional): The path to the cache directory.
Can be set using the `FASTEMBED_CACHE_PATH` env variable.
Defaults to `fastembed_cache` in the system's temp directory.
threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
"""
self.model_name = model_name

if cache_dir is None:
cache_dir = Path(".").resolve() / "local_cache"
default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache")
cache_dir = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
cache_dir.mkdir(parents=True, exist_ok=True)

self._cache_dir = cache_dir
Expand Down

0 comments on commit 0a94425

Please sign in to comment.