Skip to content

Commit

Permalink
* fix(embedding.py): fix model_name splitting to correctly extract the model name
Browse files Browse the repository at this point in the history

* fix(embedding.py): handle PermissionError when downloading fast_model_name.tar.gz and try simple_model_name.tar.gz as a fallback
* fix(embedding.py): raise ValueError if neither fast_model_name.tar.gz nor simple_model_name.tar.gz is valid
  • Loading branch information
NirantK committed Sep 25, 2023
1 parent 464b12a commit 6810141
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions fastembed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,27 @@ def retrieve_model(self, model_name: str, cache_dir: str) -> Path:

assert "/" in model_name, "model_name must be in the format <org>/<model> e.g. BAAI/bge-base-en"

model_name = model_name.split("/")[-1]

fast_model_name = f"fast-{model_name}"
fast_model_name = f"fast-{model_name.split('/')[-1]}"

model_dir = Path(cache_dir) / fast_model_name
if model_dir.exists():
return model_dir

model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
self.download_file_from_gcs(
f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
try:
self.download_file_from_gcs(
f"https://storage.googleapis.com/qdrant-fastembed/{fast_model_name}.tar.gz",
output_path=str(model_tar_gz),
)
)
except PermissionError:
simple_model_name = model_name.replace("/", "-")
print(f"Was not able to download {fast_model_name}.tar.gz, trying {simple_model_name}.tar.gz")
self.download_file_from_gcs(
f"https://storage.googleapis.com/qdrant-fastembed/{simple_model_name}.tar.gz",
output_path=str(model_tar_gz),
)
else:
raise ValueError(f"Could not find {model_tar_gz}")

self.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=cache_dir)
assert model_dir.exists(), f"Could not find {model_dir} in {cache_dir}"
Expand Down

0 comments on commit 6810141

Please sign in to comment.