Improve progress bar new #440

Merged 8 commits on Jan 27, 2025
134 changes: 125 additions & 9 deletions fastembed/common/model_management.py
@@ -1,12 +1,14 @@
 import os
 import time
+import json
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import requests
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import (
     RepositoryNotFoundError,
     disable_progress_bars,
@@ -17,6 +19,8 @@
 
 
 class ModelManagement:
+    METADATA_FILE = "files_metadata.json"
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
         cache_dir: str,
-        extra_patterns: Optional[list[str]] = None,
+        extra_patterns: list[str],
         local_files_only: bool = False,
         **kwargs,
     ) -> str:
@@ -107,36 +111,148 @@ def download_files_from_huggingface(
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
-            extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
                 includes the required model files.
             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
         Returns:
             Path: The path to the model directory.
         """
 
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, int]]:
+            meta = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
         allow_patterns = [
             "config.json",
             "tokenizer.json",
             "tokenizer_config.json",
             "special_tokens_map.json",
             "preprocessor_config.json",
         ]
-        if extra_patterns is not None:
-            allow_patterns.extend(extra_patterns)
+        allow_patterns.extend(extra_patterns)
+
+        snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
+        is_cached = snapshot_dir.exists()
+        metadata_file = snapshot_dir / cls.METADATA_FILE
+
+        if local_files_only:
+            disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            else:
+                logger.warning(
+                    "Metadata file not found. Proceeding without checking local files."
+                )  # if users have downloaded models from hf manually, or they're updating
+                # from previous versions of fastembed
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
+
+        if is_cached:
+            if verified_metadata:
+                disable_progress_bars()
 
-        return snapshot_download(
+        result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             **kwargs,
         )
+
+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files
+            # have been downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
+                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during the download process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
 
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """
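
For reference, the crux of the caching behavior in this diff is small: a cached snapshot is trusted, and the download progress bar silenced, only when every file on disk matches the size recorded in files_metadata.json (plus a blob_id check against the repo tree when online). Below is a minimal standalone sketch of the offline check; the helper name offline_files_match and the cache path are illustrative assumptions, not part of fastembed's API.

import json
from pathlib import Path


def offline_files_match(model_dir: Path, metadata_file: str = "files_metadata.json") -> bool:
    """Return True when every file recorded in the metadata exists with the recorded size."""
    meta_path = model_dir / metadata_file
    if not meta_path.exists():
        return False  # no metadata to verify against; treat the cache as unverified
    metadata = json.loads(meta_path.read_text())
    return all(
        (model_dir / rel_path).exists()
        and (model_dir / rel_path).stat().st_size == entry["size"]
        for rel_path, entry in metadata.items()
    )


# Hypothetical usage: keep progress bars only for unverified (re)downloads.
snapshot_dir = Path.home() / ".cache" / "fastembed" / "models--qdrant--all-MiniLM-L6-v2-onnx"
if offline_files_match(snapshot_dir):
    print("cache verified; disable_progress_bars() can run before snapshot_download")

Online verification in the PR works the same way, but additionally compares each entry's blob_id against the repo tree fetched with list_repo_tree, which is why missing or stale files fall through to a fresh snapshot_download with the progress bar visible.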