diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py
index 33f513c8..d85406f0 100644
--- a/fastembed/common/model_management.py
+++ b/fastembed/common/model_management.py
@@ -1,12 +1,14 @@
 import os
 import time
+import json
 import shutil
 import tarfile
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 import requests
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import (
     RepositoryNotFoundError,
     disable_progress_bars,
@@ -17,6 +19,8 @@
 
 
 class ModelManagement:
+    METADATA_FILE = "files_metadata.json"
+
     @classmethod
     def list_supported_models(cls) -> list[dict[str, Any]]:
         """Lists the supported models.
@@ -98,7 +102,7 @@ def download_files_from_huggingface(
         cls,
         hf_source_repo: str,
         cache_dir: str,
-        extra_patterns: Optional[list[str]] = None,
+        extra_patterns: list[str],
         local_files_only: bool = False,
         **kwargs,
     ) -> str:
@@ -107,12 +111,63 @@ def download_files_from_huggingface(
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
-            extra_patterns (Optional[list[str]]): extra patterns to allow in the snapshot download, typically
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
                 includes the required model files.
             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
         Returns:
             Path: The path to the model directory.
         """
+
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, int]]:
+            meta = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
         allow_patterns = [
             "config.json",
             "tokenizer.json",
@@ -120,16 +175,59 @@ def download_files_from_huggingface(
             "special_tokens_map.json",
             "preprocessor_config.json",
         ]
-        if extra_patterns is not None:
-            allow_patterns.extend(extra_patterns)
+
+        allow_patterns.extend(extra_patterns)
 
         snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
-        is_cached = snapshot_dir.exists()
+        metadata_file = snapshot_dir / cls.METADATA_FILE
+
+        if local_files_only:
+            disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            else:
+                logger.warning(
+                    "Metadata file not found. Proceeding without checking local files."
+                )  # if users have downloaded models from hf manually, or they're updating from previous versions of
+                # fastembed
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
 
-        if is_cached:
+        if verified_metadata:
             disable_progress_bars()
 
-        return snapshot_download(
+        result = snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=allow_patterns,
             cache_dir=cache_dir,
@@ -137,6 +235,24 @@ def download_files_from_huggingface(
             **kwargs,
         )
 
+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files have been
+            # downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
+                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during downloading process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
+
     @classmethod
     def decompress_to_cache(cls, targz_path: str, cache_dir: str):
         """
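A minimal sketch (not part of the patch) of the offline verification path this change introduces, assuming files_metadata.json maps each relative file path to the "size"/"blob_id" entry written by _save_file_metadata; the helper name verify_offline and the example values below are illustrative only.

    import json
    from pathlib import Path

    def verify_offline(model_dir: Path) -> bool:
        # Mirrors _verify_files_from_metadata with repo_files=[]: every file recorded
        # in the metadata must exist locally and match its recorded size.
        stored = json.loads((model_dir / "files_metadata.json").read_text())
        for rel_path, meta in stored.items():
            file_path = model_dir / rel_path
            if not file_path.exists() or file_path.stat().st_size != meta["size"]:
                return False
        return True

    # Illustrative metadata layout:
    # {"model.onnx": {"size": 1234, "blob_id": "<git blob sha>"}}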