Improve progress bar new (#440)
* improve: Improve progress bar

* fix: Fix download error when the internet connection is down

* new: Added file hash computation to track new versions

* refactor: Removed redundant hash check
fix: Fix CI

* new: Verify using hf_api

* new: Improve progress bar

* refactor new progress bar (#446)

* refactor

* chore: Remove redundant enable progress bar

---------

Co-authored-by: hh-space-invader <[email protected]>

* refactor comments

---------

Co-authored-by: George <[email protected]>
hh-space-invader and joein authored Jan 27, 2025
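For context: the commit introduces a files_metadata.json sidecar (see METADATA_FILE in the diff below) that records, for every downloaded file, the size and blob id reported by the Hub. The online path compares both values against list_repo_tree; the offline path (repo_files=[]) can only compare sizes. A minimal sketch of what the stored metadata might look like, with hypothetical file names and values:

# Hypothetical contents of files_metadata.json (names and values invented
# for illustration). Keys are file paths relative to the snapshot directory;
# "size" is the file size in bytes and "blob_id" is the git blob SHA that
# huggingface_hub reports for the file.
example_metadata = {
    "model_optimized.onnx": {"size": 90404864, "blob_id": "4af0..."},
    "config.json": {"size": 684, "blob_id": "9c1f..."},
}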

1 parent ae37da3 commit 54f6cd9
Showing 1 changed file with 125 additions and 9 deletions.
fastembed/common/model_management.py (125 additions & 9 deletions)
@@ -1,12 +1,14 @@
 import os
 import time
+import json
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import requests
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import (
     RepositoryNotFoundError,
     disable_progress_bars,
@@ -17,6 +19,8 @@
 
 
 class ModelManagement:
+    METADATA_FILE = "files_metadata.json"
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
         cache_dir: str,
-        extra_patterns: Optional[list[str]] = None,
+        extra_patterns: list[str],
         local_files_only: bool = False,
         **kwargs,
     ) -> str:
@@ -107,36 +111,148 @@ def download_files_from_huggingface(
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
-            extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
                 includes the required model files.
             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
         Returns:
             Path: The path to the model directory.
         """
 
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, int]]:
+            meta = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
         allow_patterns = [
             "config.json",
             "tokenizer.json",
             "tokenizer_config.json",
             "special_tokens_map.json",
             "preprocessor_config.json",
         ]
-        if extra_patterns is not None:
-            allow_patterns.extend(extra_patterns)
+
+        allow_patterns.extend(extra_patterns)
+
+        snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
+        is_cached = snapshot_dir.exists()
+        metadata_file = snapshot_dir / cls.METADATA_FILE
 
         if local_files_only:
             disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            else:
+                logger.warning(
+                    "Metadata file not found. Proceeding without checking local files."
+                )  # if users have downloaded models from hf manually, or they're updating from previous versions of
+                # fastembed
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
 
-        return snapshot_download(
+        if is_cached:
+            if verified_metadata:
+                disable_progress_bars()
+
+        result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             **kwargs,
         )
 
+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files have been
+            # downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
+                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during downloading process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
+
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """

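A minimal usage sketch of the updated method (the repo id is taken from the docstring example; the cache directory and weights pattern are assumptions, and fastembed normally calls this internally when a model is loaded):

from fastembed.common.model_management import ModelManagement

# Hypothetical direct call. Note that extra_patterns is now required rather
# than Optional. On a cache hit with verified metadata, progress bars are
# disabled; after a fresh download the files are re-verified offline and the
# metadata file is written.
model_dir = ModelManagement.download_files_from_huggingface(
    hf_source_repo="qdrant/all-MiniLM-L6-v2-onnx",
    cache_dir="model_cache",  # assumed local cache directory
    extra_patterns=["model_optimized.onnx"],  # assumed pattern for the weights
)
print(model_dir)  # path to the downloaded and verified snapshot

Comparing sizes and Hub-reported blob ids keeps verification cheap (no local hashing of large ONNX files), which lines up with the "Removed redundant hash check" bullet in the commit message.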