Skip to content

Commit

Permalink
Enable HF Transfer in tests (#51)
Browse files Browse the repository at this point in the history
* Enable `hf-transfer`

* Add a flag for the user-agent in ci

* Add --fix for ruff in fix-quality make rule
  • Loading branch information
mfuntowicz authored Dec 16, 2023
1 parent 7ddf18c commit 9f65740
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr_slow_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ env:
RUN_CPU_ONLY: OFF
RUN_NIGHTLY: OFF
RUN_SLOW: ON
HF_HUB_ENABLE_HF_TRANSFER: ON

jobs:
run_gpu_tests:
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fix-quality:
python3 -m ruff check examples scripts src tests
python3 -m ruff check examples scripts src tests --fix
python3 -m ruff format examples scripts src tests

quality:
Expand Down
12 changes: 11 additions & 1 deletion src/optimum/nvidia/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..version import __version__
from .nvml import get_device_compute_capabilities
from .tests.utils import parse_flag_from_env


USER_AGENT_BASE = [f"optimum/nvidia/{__version__}", f"python/{pyversion.split()[0]}"]
Expand All @@ -18,6 +19,7 @@ def get_user_agent() -> str:
"""
ua = USER_AGENT_BASE.copy()

# Nvidia driver / devices
try:
nvmlInit()
ua.append(f"nvidia/{nvmlSystemGetDriverVersion()}")
Expand All @@ -36,6 +38,7 @@ def get_user_agent() -> str:
except (RuntimeError, ImportError):
ua.append("nvidia/unknown")

# Torch / CUDA related version, (from torch)
try:
from torch import __version__ as pt_version
from torch.version import cuda, cudnn
Expand All @@ -46,19 +49,26 @@ def get_user_agent() -> str:
except ImportError:
pass

# TODO: Refactor later on
# transformers version
try:
from transformers import __version__ as tfrs_version

ua.append(f"transformers/{tfrs_version}")
except ImportError:
pass

# TRTLLM version
try:
from tensorrt_llm._utils import trt_version

ua.append(f"tensorrt/{trt_version()}")
except ImportError:
pass

# Add a flag for CI
if parse_flag_from_env("OPTIMUM_NVIDIA_IS_CI", False):
ua.append("is_ci/true")
else:
ua.append("is_ci/false")

return "; ".join(ua)

0 comments on commit 9f65740

Please sign in to comment.