Skip to content

Commit

Permalink
Merge pull request #72 from michaelfeil/release-for-localfiles
Browse files Browse the repository at this point in the history
update optimum_utils
  • Loading branch information
michaelfeil authored Jan 19, 2024
2 parents bdc144a + 2537996 commit e30c584
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

try:
from huggingface_hub import HfApi, HfFolder # type: ignore
from huggingface_hub.constants import HF_HUB_CACHE # type: ignore
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE # type: ignore
from optimum.onnxruntime import ORTOptimizer # type: ignore
from optimum.onnxruntime.configuration import OptimizationConfig # type: ignore
except ImportError:
Expand Down Expand Up @@ -42,7 +42,7 @@ def optimize_model(
path_folder = (
Path(model_name_or_path)
if Path(model_name_or_path).exists()
else Path(HF_HUB_CACHE) / "infinity_onnx" / model_name_or_path
else Path(HUGGINGFACE_HUB_CACHE) / "infinity_onnx" / model_name_or_path
)
files_optimized = list(path_folder.glob("**/*optimized.onnx"))
if files_optimized and not execution_provider == "TensorrtExecutionProvider":
Expand Down Expand Up @@ -93,19 +93,27 @@ def optimize_model(


def get_onnx_files(
model_id: str,
model_name_or_path: str,
revision: str,
use_auth_token: Union[bool, str] = True,
prefer_quantized=False,
) -> Path:
"""gets the onnx files from the repo"""
if isinstance(use_auth_token, bool):
token = HfFolder().get_token()
if not Path(model_name_or_path).exists():
if isinstance(use_auth_token, bool):
token = HfFolder().get_token()
else:
token = use_auth_token
repo_files = list(
map(
Path,
HfApi().list_repo_files(
model_name_or_path, revision=revision, token=token
),
)
)
else:
token = use_auth_token
repo_files = map(
Path, HfApi().list_repo_files(model_id, revision=revision, token=token)
)
repo_files = list(Path(model_name_or_path).glob("**/*"))
pattern = "**.onnx"
onnx_files = [p for p in repo_files if p.match(pattern)]

Expand All @@ -121,4 +129,4 @@ def get_onnx_files(
elif len(onnx_files) == 1:
return onnx_files[0]
else:
raise ValueError(f"No onnx files found for {model_id}")
raise ValueError(f"No onnx files found for {model_name_or_path}")

0 comments on commit e30c584

Please sign in to comment.