diff --git a/.github/workflows/test_cli_cpu_py_txi.yaml b/.github/workflows/test_cli_cpu_py_txi.yaml
index 7b1946e7..06bd841d 100644
--- a/.github/workflows/test_cli_cpu_py_txi.yaml
+++ b/.github/workflows/test_cli_cpu_py_txi.yaml
@@ -43,9 +43,12 @@ jobs:

       - name: Install requirements
         run: |
-          pip install --upgrade pip
-          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install -e .[testing,py-txi] git+https://github.com/IlyasMoutawwakil/py-txi.git
+          pip install uv
+          uv pip install --upgrade pip
+          uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          uv pip install -e .[testing,py-txi] git+https://github.com/IlyasMoutawwakil/py-txi.git
+        env:
+          UV_SYSTEM_PYTHON: 1

       - name: Run tests
         run: pytest tests/test_cli.py -s -k "cli and cpu and py_txi"
diff --git a/.github/workflows/test_cli_cuda_py_txi.yaml b/.github/workflows/test_cli_cuda_py_txi.yaml
index 5c090b28..a7fe9a51 100644
--- a/.github/workflows/test_cli_cuda_py_txi.yaml
+++ b/.github/workflows/test_cli_cuda_py_txi.yaml
@@ -44,11 +44,16 @@ jobs:

       - name: Install requirements
         run: |
-          pip install --upgrade pip
-          pip install -e .[testing,py-txi] git+https://github.com/IlyasMoutawwakil/py-txi.git
+          pip install uv
+          uv pip install --upgrade pip
+          uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          uv pip install -e .[testing,py-txi] git+https://github.com/IlyasMoutawwakil/py-txi.git
+        env:
+          UV_SYSTEM_PYTHON: 1

       - name: Run tests
-        run: pytest tests/test_cli.py -x -s -k "cli and cuda and py_txi"
+        run: |
+          FORCE_SEQUENTIAL=1 pytest tests/test_cli.py -x -s -k "cli and cuda and py_txi"

       - if: ${{
           (github.event_name == 'push') ||
@@ -56,4 +61,5 @@ jobs:
           contains( github.event.pull_request.labels.*.name, 'examples')
           }}
         name: Run examples
-        run: pytest tests/test_examples.py -x -s -k "cli and cuda and (tgi or tei)"
+        run: |
+          FORCE_SEQUENTIAL=1 pytest tests/test_examples.py -x -s -k "cli and cuda and (tgi or tei)"
diff --git a/README.md b/README.md
index 6358b341..9203b778 100644
--- a/README.md
+++ b/README.md
@@ -50,7 +50,6 @@ Optimum-Benchmark is continuously and intensively tested on a variety of devices
 [![CLI_CPU_IPEX](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_ipex.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_ipex.yaml)
 [![CLI_CPU_LLAMA_CPP](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_llama_cpp.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_llama_cpp.yaml)
-[![CLI_CPU_NEURAL_COMPRESSOR](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_neural_compressor.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_neural_compressor.yaml)
 [![CLI_CPU_ONNXRUNTIME](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_onnxruntime.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_onnxruntime.yaml)
 [![CLI_CPU_OPENVINO](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_openvino.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_openvino.yaml)
 [![CLI_CPU_PYTORCH](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_pytorch.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cpu_pytorch.yaml)
@@ -61,7 +60,6 @@ Optimum-Benchmark is continuously and intensively tested on a variety of devices
 [![CLI_CUDA_TENSORRT_LLM](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_tensorrt_llm.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_tensorrt_llm.yaml)
 [![CLI_CUDA_TORCH_ORT](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_torch_ort.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_torch_ort.yaml)
 [![CLI_CUDA_VLLM](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_vllm.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_vllm.yaml)
-[![CLI_ENERGY_STAR](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_energy_star.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_energy_star.yaml)
 [![CLI_MISC](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_misc.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_misc.yaml)
 [![CLI_ROCM_PYTORCH](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_rocm_pytorch.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_rocm_pytorch.yaml)
@@ -100,10 +98,9 @@ Depending on the backends you want to use, you can install `optimum-benchmark` w
 - OnnxRuntime: `pip install optimum-benchmark[onnxruntime]`
 - TensorRT-LLM: `pip install optimum-benchmark[tensorrt-llm]`
 - OnnxRuntime-GPU: `pip install optimum-benchmark[onnxruntime-gpu]`
-- Neural Compressor: `pip install optimum-benchmark[neural-compressor]`
-- Py-TXI: `pip install optimum-benchmark[py-txi]`
-- IPEX: `pip install optimum-benchmark[ipex]`
+- Py-TXI (TGI & TEI): `pip install optimum-benchmark[py-txi]`
 - vLLM: `pip install optimum-benchmark[vllm]`
+- IPEX: `pip install optimum-benchmark[ipex]`

 We also support the following extra extra dependencies:

@@ -144,9 +141,6 @@ if __name__ == "__main__":
     )

     benchmark_report = Benchmark.launch(benchmark_config)

-    # log the benchmark in terminal
-    benchmark_report.log() # or print(benchmark_report)
-
     # convert artifacts to a dictionary or dataframe
     benchmark_config.to_dict() # or benchmark_config.to_dataframe()
@@ -175,15 +169,17 @@ If you're on VSCode, you can hover over the configuration classes to see the ava
 You can also run a benchmark using the command line by specifying the configuration directory and the configuration name. Both arguments are mandatory for [`hydra`](https://hydra.cc/). `--config-dir` is the directory where the configuration files are stored and `--config-name` is the name of the configuration file without its `.yaml` extension.

 ```bash
-optimum-benchmark --config-dir examples/ --config-name pytorch_bert
+optimum-benchmark --config-dir examples/ --config-name cuda_pytorch_bert
 ```

-This will run the benchmark using the configuration in [`examples/pytorch_bert.yaml`](examples/pytorch_bert.yaml) and store the results in `runs/pytorch_bert`.
+This will run the benchmark using the configuration in [`examples/cuda_pytorch_bert.yaml`](examples/cuda_pytorch_bert.yaml) and store the results in `runs/cuda_pytorch_bert`.

 The resulting files are :

 - `benchmark_config.json` which contains the configuration used for the benchmark, including the backend, launcher, scenario and the environment in which the benchmark was run.
 - `benchmark_report.json` which contains a full report of the benchmark's results, like latency measurements, memory usage, energy consumption, etc.
+- `benchmark_report.txt` which contains a detailed report of the benchmark's results, in the same format they were logged.
+- `benchmark_report.md` which contains a detailed report of the benchmark's results, in markdown format.
 - `benchmark.json` contains both the report and the configuration in a single file.
 - `benchmark.log` contains the logs of the benchmark run.

@@ -309,9 +305,7 @@ For more information on the features of each backend, you can check their respec
 - [PyTorchConfig](optimum_benchmark/backends/pytorch/config.py)
 - [ORTConfig](optimum_benchmark/backends/onnxruntime/config.py)
 - [TorchORTConfig](optimum_benchmark/backends/torch_ort/config.py)
-- [LLMSwarmConfig](optimum_benchmark/backends/llm_swarm/config.py)
 - [TRTLLMConfig](optimum_benchmark/backends/tensorrt_llm/config.py)
-- [INCConfig](optimum_benchmark/backends/neural_compressor/config.py)
diff --git a/examples/cuda_pytorch_bert.yaml b/examples/cuda_pytorch_bert.yaml
index 8ab9b5cb..195e8a02 100644
--- a/examples/cuda_pytorch_bert.yaml
+++ b/examples/cuda_pytorch_bert.yaml
@@ -6,7 +6,7 @@ defaults:
   - _base_
   - _self_

-name: pytorch_bert
+name: cuda_pytorch_bert

 launcher:
   device_isolation: true
diff --git a/optimum_benchmark/backends/config.py b/optimum_benchmark/backends/config.py
index fc265d4d..c47b7366 100644
--- a/optimum_benchmark/backends/config.py
+++ b/optimum_benchmark/backends/config.py
@@ -22,13 +22,13 @@ class BackendConfig(ABC):
     version: str
     _target_: str

+    model: Optional[str] = None
+    processor: Optional[str] = None
+
     task: Optional[str] = None
     library: Optional[str] = None
     model_type: Optional[str] = None

-    model: Optional[str] = None
-    processor: Optional[str] = None
-
     device: Optional[str] = None
     # we use a string here instead of a list
     # because it's easier to pass in a yaml or from cli
@@ -48,30 +48,44 @@ def __post_init__(self):
         if self.model is None:
             raise ValueError("`model` must be specified.")

+        if self.model_kwargs.get("token", None) is not None:
+            LOGGER.info(
+                "You have passed an argument `token` to `model_kwargs`. This is dangerous as the config cannot do encryption to protect it. "
+                "We will proceed to registering `token` in the environment as `HF_TOKEN` to avoid saving it or pushing it to the hub by mistake."
+            )
+            os.environ["HF_TOKEN"] = self.model_kwargs.pop("token")
+
         if self.processor is None:
             self.processor = self.model

-        # TODO: add cache_dir, token, etc. to these methods
+        if not self.processor_kwargs:
+            self.processor_kwargs = self.model_kwargs
+
         if self.library is None:
             self.library = infer_library_from_model_name_or_path(
                 model_name_or_path=self.model,
-                token=self.model_kwargs.get("token", None),
                 revision=self.model_kwargs.get("revision", None),
+                cache_dir=self.model_kwargs.get("cache_dir", None),
+            )
+
+        if self.library not in ["transformers", "diffusers", "timm", "llama_cpp"]:
+            raise ValueError(
+                f"`library` must be either `transformers`, `diffusers`, `timm` or `llama_cpp`, but got {self.library}"
             )

         if self.task is None:
             self.task = infer_task_from_model_name_or_path(
                 model_name_or_path=self.model,
-                token=self.model_kwargs.get("token", None),
                 revision=self.model_kwargs.get("revision", None),
+                cache_dir=self.model_kwargs.get("cache_dir", None),
                 library_name=self.library,
             )

         if self.model_type is None:
             self.model_type = infer_model_type_from_model_name_or_path(
                 model_name_or_path=self.model,
-                token=self.model_kwargs.get("token", None),
                 revision=self.model_kwargs.get("revision", None),
+                cache_dir=self.model_kwargs.get("cache_dir", None),
                 library_name=self.library,
             )
@@ -103,11 +117,6 @@ def __post_init__(self):
             else:
                 raise RuntimeError("CUDA device is only supported on systems with NVIDIA or ROCm drivers.")

-        if self.library not in ["transformers", "diffusers", "timm", "llama_cpp"]:
-            raise ValueError(
-                f"`library` must be either `transformers`, `diffusers`, `timm` or `llama_cpp`, but got {self.library}"
-            )
-
         if self.inter_op_num_threads is not None:
             if self.inter_op_num_threads == -1:
                 self.inter_op_num_threads = cpu_count()
diff --git a/optimum_benchmark/backends/py_txi/backend.py b/optimum_benchmark/backends/py_txi/backend.py
index 6e637a31..55aecab9 100644
--- a/optimum_benchmark/backends/py_txi/backend.py
+++ b/optimum_benchmark/backends/py_txi/backend.py
@@ -1,11 +1,12 @@
-import os
+import shutil
+from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

 import torch
-from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
 from py_txi import TEI, TGI, TEIConfig, TGIConfig
-from safetensors.torch import save_file
+from safetensors.torch import save_model

 from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS
 from ..base import Backend
@@ -15,6 +16,7 @@

 class PyTXIBackend(Backend[PyTXIConfig]):
     NAME: str = "py-txi"
+    pretrained_model: Union[TEI, TGI]

     def __init__(self, config: PyTXIConfig) -> None:
         super().__init__(config)
@@ -31,114 +33,141 @@ def load(self) -> None:
         else:
             self.logger.info("\t+ Downloading pretrained model")
             self.download_pretrained_model()
-
-            if self.config.task in TEXT_GENERATION_TASKS:
-                self.logger.info("\t+ Preparing generation config")
-                self.prepare_generation_config()
-
             self.logger.info("\t+ Loading pretrained model")
             self.load_model_from_pretrained()

-        self.tmpdir.cleanup()
-
-    @property
-    def volume(self) -> str:
-        return list(self.config.volumes.keys())[0]
+        try:
+            self.tmpdir.cleanup()
+        except Exception:
+            shutil.rmtree(self.tmpdir.name, ignore_errors=True)

     def download_pretrained_model(self) -> None:
-        # directly downloads pretrained model in volume (/data) to change generation config before loading model
-        with init_empty_weights(include_buffers=True):
-            self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs, cache_dir=self.volume)
-
-    def prepare_generation_config(self) -> None:
-        self.generation_config.eos_token_id = None
-        self.generation_config.pad_token_id = None
-
-        model_cache_folder = f"models/{self.config.model}".replace("/", "--")
-        model_cache_path = f"{self.volume}/{model_cache_folder}"
-        snapshot_file = f"{model_cache_path}/refs/{self.config.model_kwargs.get('revision', 'main')}"
-        snapshot_ref = open(snapshot_file, "r").read().strip()
-        model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
-        self.logger.info("\t+ Saving new pretrained generation config")
-        self.generation_config.save_pretrained(save_directory=model_snapshot_path)
+        model_snapshot_folder = snapshot_download(self.config.model, **self.config.model_kwargs)
+
+        if self.config.task in TEXT_GENERATION_TASKS:
+            self.generation_config.eos_token_id = None
+            self.generation_config.pad_token_id = None
+            self.generation_config.save_pretrained(save_directory=model_snapshot_folder)

     def create_no_weights_model(self) -> None:
-        self.no_weights_model = os.path.join(self.tmpdir.name, "no_weights_model")
-        self.logger.info("\t+ Creating no weights model directory")
-        os.makedirs(self.no_weights_model, exist_ok=True)
-        self.logger.info("\t+ Creating no weights model state dict")
-        state_dict = torch.nn.Linear(1, 1).state_dict()
-        self.logger.info("\t+ Saving no weights model safetensors")
-        safetensor = os.path.join(self.no_weights_model, "model.safetensors")
-        save_file(tensors=state_dict, filename=safetensor, metadata={"format": "pt"})
-        self.logger.info("\t+ Saving no weights model pretrained config")
-        self.pretrained_config.save_pretrained(save_directory=self.no_weights_model)
-        self.logger.info("\t+ Saving no weights model pretrained processor")
-        self.pretrained_processor.save_pretrained(save_directory=self.no_weights_model)
-        # unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
-        self.logger.info(f"\t+ Loading no weights model from {self.no_weights_model}")
+        model_path = Path(hf_hub_download(self.config.model, filename="config.json", cache_dir=self.tmpdir.name)).parent
+        save_model(model=torch.nn.Linear(1, 1), filename=model_path / "model.safetensors", metadata={"format": "pt"})
+
+        self.pretrained_processor.save_pretrained(save_directory=model_path)
+        self.pretrained_config.save_pretrained(save_directory=model_path)
+        with fast_weights_init():
+            # unlike Transformers, TXI won't accept any missing tensors so we need to materialize the model
             self.pretrained_model = self.automodel_loader.from_pretrained(
-                self.no_weights_model, **self.config.model_kwargs, device_map="auto", _fast_init=False
+                model_path,
+                _fast_init=False,
+                device_map="auto",
+                **self.config.model_kwargs,
             )
-        self.logger.info("\t+ Saving no weights model")
-        self.pretrained_model.save_pretrained(save_directory=self.no_weights_model)
+
+        save_model(model=self.pretrained_model, filename=model_path / "model.safetensors", metadata={"format": "pt"})
         del self.pretrained_model
         torch.cuda.empty_cache()

         if self.config.task in TEXT_GENERATION_TASKS:
-            self.logger.info("\t+ Modifying generation config for fixed length generation")
             self.generation_config.eos_token_id = None
             self.generation_config.pad_token_id = None
-            self.logger.info("\t+ Saving new pretrained generation config")
-            self.generation_config.save_pretrained(save_directory=self.no_weights_model)
+            self.generation_config.save_pretrained(save_directory=model_path)

     def load_model_with_no_weights(self) -> None:
-        original_volumes, self.config.volumes = self.config.volumes, {self.tmpdir.name: {"bind": "/data", "mode": "rw"}}
-        original_model, self.config.model = self.config.model, "/data/no_weights_model"
-        self.logger.info("\t+ Loading no weights model")
+        self.config.volumes = {self.tmpdir.name: {"bind": "/data", "mode": "rw"}}
         self.load_model_from_pretrained()
-        self.config.model, self.config.volumes = original_model, original_volumes

     def load_model_from_pretrained(self) -> None:
         if self.config.task in TEXT_GENERATION_TASKS:
             self.pretrained_model = TGI(
-                config=TGIConfig(
-                    model_id=self.config.model,
-                    gpus=self.config.gpus,
-                    devices=self.config.devices,
-                    volumes=self.config.volumes,
-                    environment=self.config.environment,
-                    ports=self.config.ports,
-                    dtype=self.config.dtype,
-                    sharded=self.config.sharded,
-                    quantize=self.config.quantize,
-                    num_shard=self.config.num_shard,
-                    speculate=self.config.speculate,
-                    cuda_graphs=self.config.cuda_graphs,
-                    disable_custom_kernels=self.config.disable_custom_kernels,
-                    trust_remote_code=self.config.trust_remote_code,
-                    max_concurrent_requests=self.config.max_concurrent_requests,
-                ),
+                config=TGIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
             )
-
         elif self.config.task in TEXT_EMBEDDING_TASKS:
             self.pretrained_model = TEI(
-                config=TEIConfig(
-                    model_id=self.config.model,
-                    gpus=self.config.gpus,
-                    devices=self.config.devices,
-                    volumes=self.config.volumes,
-                    environment=self.config.environment,
-                    ports=self.config.ports,
-                    dtype=self.config.dtype,
-                    pooling=self.config.pooling,
-                    max_concurrent_requests=self.config.max_concurrent_requests,
-                ),
+                config=TEIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tei_kwargs),
             )
         else:
             raise NotImplementedError(f"TXI does not support task {self.config.task}")

+    @property
+    def txi_kwargs(self):
+        kwargs = {}
+
+        if self.config.gpus is not None:
+            kwargs["gpus"] = self.config.gpus
+
+        if self.config.image is not None:
+            kwargs["image"] = self.config.image
+
+        if self.config.ports is not None:
+            kwargs["ports"] = self.config.ports
+
+        if self.config.volumes is not None:
+            kwargs["volumes"] = self.config.volumes
+
+        if self.config.devices is not None:
+            kwargs["devices"] = self.config.devices
+
+        if self.config.shm_size is not None:
+            kwargs["shm_size"] = self.config.shm_size
+
+        if self.config.environment is not None:
+            kwargs["environment"] = self.config.environment
+
+        if self.config.connection_timeout is not None:
+            kwargs["connection_timeout"] = self.config.connection_timeout
+
+        if self.config.first_request_timeout is not None:
+            kwargs["first_request_timeout"] = self.config.first_request_timeout
+
+        if self.config.max_concurrent_requests is not None:
+            kwargs["max_concurrent_requests"] = self.config.max_concurrent_requests
+
+        return kwargs
+
+    @property
+    def tei_kwargs(self):
+        kwargs = {}
+
+        if self.config.dtype is not None:
+            kwargs["dtype"] = self.config.dtype
+
+        if self.config.pooling is not None:
+            kwargs["pooling"] = self.config.pooling
+
+        return kwargs
+
+    @property
+    def tgi_kwargs(self):
+        kwargs = {}
+
+        if self.config.dtype is not None:
+            kwargs["dtype"] = self.config.dtype
+
+        if self.config.sharded is not None:
+            kwargs["sharded"] = self.config.sharded
+
+        if self.config.quantize is not None:
+            kwargs["quantize"] = self.config.quantize
+
+        if self.config.num_shard is not None:
+            kwargs["num_shard"] = self.config.num_shard
+
+        if self.config.speculate is not None:
+            kwargs["speculate"] = self.config.speculate
+
+        if self.config.cuda_graphs is not None:
+            kwargs["cuda_graphs"] = self.config.cuda_graphs
+
+        if self.config.trust_remote_code is not None:
+            kwargs["trust_remote_code"] = self.config.trust_remote_code
+
+        if self.config.disable_custom_kernels is not None:
+            kwargs["disable_custom_kernels"] = self.config.disable_custom_kernels
+
+        return kwargs
+
     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.config.task in TEXT_GENERATION_TASKS:
             inputs = {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
diff --git a/optimum_benchmark/backends/py_txi/config.py b/optimum_benchmark/backends/py_txi/config.py
index dae410c4..3b4a908d 100644
--- a/optimum_benchmark/backends/py_txi/config.py
+++ b/optimum_benchmark/backends/py_txi/config.py
@@ -1,9 +1,7 @@
 import os
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

-from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-
 from ...import_utils import py_txi_version
 from ...system_utils import is_nvidia_system, is_rocm_system
 from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS
@@ -16,35 +14,30 @@ class PyTXIConfig(BackendConfig):
     version: Optional[str] = py_txi_version()
     _target_: str = "optimum_benchmark.backends.py_txi.backend.PyTXIBackend"

-    # optimum benchmark specific
+    # optimum-benchmark specific
     no_weights: bool = False

     # Image to use for the container
     image: Optional[str] = None
     # Shared memory size for the container
-    shm_size: str = "1g"
+    shm_size: Optional[str] = None
     # List of custom devices to forward to the container e.g. ["/dev/kfd", "/dev/dri"] for ROCm
     devices: Optional[List[str]] = None
     # NVIDIA-docker GPU device options e.g. "all" (all) or "0,1,2,3" (ids) or 4 (count)
     gpus: Optional[Union[str, int]] = None
     # Things to forward to the container
-    ports: Dict[str, Any] = field(
-        default_factory=lambda: {"80/tcp": ("127.0.0.1", 0)},
-        metadata={"help": "Dictionary of ports to expose from the container."},
-    )
-    volumes: Dict[str, Any] = field(
-        default_factory=lambda: {HUGGINGFACE_HUB_CACHE: {"bind": "/data", "mode": "rw"}},
-        metadata={"help": "Dictionary of volumes to mount inside the container."},
-    )
-    environment: List[str] = field(
-        default_factory=lambda: ["HUGGING_FACE_HUB_TOKEN"],
-        metadata={"help": "List of environment variables to forward to the container from the host."},
-    )
+    ports: Optional[Dict[str, Any]] = None
+    environment: Optional[List[str]] = None
+    volumes: Optional[Dict[str, Any]] = None
+    # First connection/request
+    connection_timeout: Optional[int] = None
+    first_request_timeout: Optional[int] = None
+    max_concurrent_requests: Optional[int] = None

     # Common options
     dtype: Optional[str] = None
-    max_concurrent_requests: Optional[int] = None
-
+    # TEI specific
+    pooling: Optional[str] = None
     # TGI specific
     sharded: Optional[str] = None
     quantize: Optional[str] = None
@@ -54,9 +47,6 @@ class PyTXIConfig(BackendConfig):
     trust_remote_code: Optional[bool] = None
     disable_custom_kernels: Optional[bool] = None

-    # TEI specific
-    pooling: Optional[str] = None
-
     def __post_init__(self):
         super().__post_init__()

@@ -72,14 +62,4 @@ def __post_init__(self):
             renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")]
             self.devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in ids]

-        # Common options
-        if self.max_concurrent_requests is None:
-            if self.task in TEXT_GENERATION_TASKS:
-                self.max_concurrent_requests = 128
-            elif self.task in TEXT_EMBEDDING_TASKS:
-                self.max_concurrent_requests = 512
-
-        # TGI specific
-        if self.task in TEXT_GENERATION_TASKS:
-            if self.trust_remote_code is None:
-                self.trust_remote_code = self.model_kwargs.get("trust_remote_code", False)
+        self.trust_remote_code = self.model_kwargs.get("trust_remote_code", None)
diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py
index 7c066d14..45e3a342 100644
--- a/optimum_benchmark/task_utils.py
+++ b/optimum_benchmark/task_utils.py
@@ -155,7 +155,11 @@ def is_local_dir_repo(model_name_or_path: str) -> bool:


 def get_repo_config(
-    model_name_or_path: str, config_name: str, token: Optional[str] = None, revision: Optional[str] = None
+    model_name_or_path: str,
+    config_name: str,
+    token: Optional[str] = None,
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ):
     if is_hf_hub_repo(model_name_or_path, token=token):
         config = json.load(
@@ -163,6 +167,7 @@ def get_repo_config(
             huggingface_hub.hf_hub_download(
                 repo_id=model_name_or_path,
                 filename=config_name,
+                cache_dir=cache_dir,
                 revision=revision,
                 token=token,
             ),
@@ -197,6 +202,7 @@ def infer_library_from_model_name_or_path(
     model_name_or_path: str,
     token: Optional[str] = None,
     revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
 ) -> str:
     inferred_library_name = None

@@ -209,7 +215,9 @@ def infer_library_from_model_name_or_path(
             inferred_library_name = "sentence-transformers"

         elif "config.json" in repo_files:
-            config_dict = get_repo_config(model_name_or_path, "config.json", token=token, revision=revision)
+            config_dict = get_repo_config(
+                model_name_or_path, "config.json", token=token, revision=revision, cache_dir=cache_dir
+            )

             if "pretrained_cfg" in config_dict:
                 inferred_library_name = "timm"
@@ -229,12 +237,15 @@ def infer_task_from_model_name_or_path(
     model_name_or_path: str,
     token: Optional[str] = None,
     revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
     library_name: Optional[str] = None,
 ) -> str:
     inferred_task_name = None

     if library_name is None:
-        library_name = infer_library_from_model_name_or_path(model_name_or_path, revision=revision, token=token)
+        library_name = infer_library_from_model_name_or_path(
+            model_name_or_path, revision=revision, token=token, cache_dir=cache_dir
+        )

     if library_name == "llama_cpp":
         inferred_task_name = "text-generation"
@@ -243,7 +254,9 @@ def infer_task_from_model_name_or_path(
         inferred_task_name = "image-classification"

     elif library_name == "transformers":
-        transformers_config = get_repo_config(model_name_or_path, "config.json", token=token, revision=revision)
+        transformers_config = get_repo_config(
+            model_name_or_path, "config.json", token=token, revision=revision, cache_dir=cache_dir
+        )
         target_class_name = transformers_config["architectures"][0]

         for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES.items():
@@ -258,7 +271,9 @@ def infer_task_from_model_name_or_path(
             raise KeyError(f"Could not find the proper task name for target class name {target_class_name}.")

     elif library_name == "diffusers":
-        diffusers_config = get_repo_config(model_name_or_path, "model_index.json", token=token, revision=revision)
+        diffusers_config = get_repo_config(
+            model_name_or_path, "model_index.json", token=token, revision=revision, cache_dir=cache_dir
+        )
         target_class_name = diffusers_config["_class_name"]

         for task_name, pipeline_mapping in TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES.items():
@@ -279,26 +294,35 @@ def infer_model_type_from_model_name_or_path(
     model_name_or_path: str,
     token: Optional[str] = None,
     revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
     library_name: Optional[str] = None,
 ) -> str:
     inferred_model_type = None

     if library_name is None:
-        library_name = infer_library_from_model_name_or_path(model_name_or_path, revision=revision, token=token)
+        library_name = infer_library_from_model_name_or_path(
+            model_name_or_path, revision=revision, token=token, cache_dir=cache_dir
+        )

     if library_name == "llama_cpp":
         inferred_model_type = "llama_cpp"

     elif library_name == "timm":
-        timm_config = get_repo_config(model_name_or_path, "config.json", token=token, revision=revision)
+        timm_config = get_repo_config(
+            model_name_or_path, "config.json", token=token, revision=revision, cache_dir=cache_dir
+        )
         inferred_model_type = timm_config["architecture"]

     elif library_name == "transformers":
-        transformers_config = get_repo_config(model_name_or_path, "config.json", token=token, revision=revision)
+        transformers_config = get_repo_config(
+            model_name_or_path, "config.json", token=token, revision=revision, cache_dir=cache_dir
+        )
         inferred_model_type = transformers_config["model_type"]

     elif library_name == "diffusers":
-        diffusers_config = get_repo_config(model_name_or_path, "model_index.json", token=token, revision=revision)
+        diffusers_config = get_repo_config(
+            model_name_or_path, "model_index.json", token=token, revision=revision, cache_dir=cache_dir
+        )
         target_class_name = diffusers_config["_class_name"]

         for _, pipeline_mapping in TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES.items():
@@ -310,6 +334,7 @@ def infer_model_type_from_model_name_or_path(
                     break

         if inferred_model_type is None:
-            raise KeyError(f"Could not find the proper model type for target class name {target_class_name}.")
+            # we use the class name in this case
+            inferred_model_type = target_class_name.replace("DiffusionPipeline", "").replace("Pipeline", "")

     return inferred_model_type
diff --git a/tests/configs/cpu_inference_py_txi_gpt2.yaml b/tests/configs/cpu_inference_py_txi_gpt2.yaml
index 76e90775..1aef598e 100644
--- a/tests/configs/cpu_inference_py_txi_gpt2.yaml
+++ b/tests/configs/cpu_inference_py_txi_gpt2.yaml
@@ -3,6 +3,7 @@ defaults:
   - _base_ # inherits from base config
   - _cpu_ # inherits from cpu config
   - _inference_ # inherits from inference config
+  - _no_weights_ # inherits from no weights config
   - _gpt2_ # inherits from gpt2 config
   - _self_ # hydra 1.1 compatibility
   - override backend: py-txi
diff --git a/tests/configs/cpu_inference_py_txi_st_bert.yaml b/tests/configs/cpu_inference_py_txi_st_bert.yaml
index 2650e1bf..99e571b5 100644
--- a/tests/configs/cpu_inference_py_txi_st_bert.yaml
+++ b/tests/configs/cpu_inference_py_txi_st_bert.yaml
@@ -3,6 +3,7 @@ defaults:
   - _base_ # inherits from base config
   - _cpu_ # inherits from cpu config
   - _inference_ # inherits from inference config
+  - _no_weights_ # inherits from no weights config
   - _st_bert_ # inherits from bert config
   - _self_ # hydra 1.1 compatibility
   - override backend: py-txi
diff --git a/tests/configs/cuda_inference_py_txi_gpt2.yaml b/tests/configs/cuda_inference_py_txi_gpt2.yaml
index 73a5c10a..1c93ac36 100644
--- a/tests/configs/cuda_inference_py_txi_gpt2.yaml
+++ b/tests/configs/cuda_inference_py_txi_gpt2.yaml
@@ -3,6 +3,7 @@ defaults:
   - _base_ # inherits from base config
   - _cuda_ # inherits from cuda config
   - _inference_ # inherits from inference config
+  - _no_weights_ # inherits from no weights config
   - _gpt2_ # inherits from gpt2 config
   - _self_ # hydra 1.1 compatibility
   - override backend: py-txi
diff --git a/tests/configs/cuda_inference_py_txi_st_bert.yaml b/tests/configs/cuda_inference_py_txi_st_bert.yaml
index 8ae494e7..5bb38528 100644
--- a/tests/configs/cuda_inference_py_txi_st_bert.yaml
+++ b/tests/configs/cuda_inference_py_txi_st_bert.yaml
@@ -3,6 +3,7 @@ defaults:
   - _base_ # inherits from base config
   - _cuda_ # inherits from cuda config
   - _inference_ # inherits from inference config
+  - _no_weights_ # inherits from no weights config
   - _st_bert_ # inherits from bert config
   - _self_ # hydra 1.1 compatibility
   - override backend: py-txi
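Usage sketch (not part of the patch above): after these changes, every container/TGI/TEI option on `PyTXIConfig` is optional and only forwarded to py-txi when set, and `no_weights: true` benchmarks a randomly-initialized copy of the model. The snippet below is a minimal, hedged illustration of driving the new config from the Python API shown in the README hunk; the `PyTXIConfig` import path follows the `_target_` in this diff, while the model id, launcher and scenario choices are illustrative assumptions.

```python
# Hypothetical usage sketch for the reworked py-txi backend; not part of the patch.
# The model id, launcher and scenario choices below are illustrative assumptions.
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig
from optimum_benchmark.backends.py_txi.config import PyTXIConfig

if __name__ == "__main__":
    backend_config = PyTXIConfig(
        model="gpt2",
        device="cpu",
        no_weights=True,  # serve a randomly-initialized copy instead of downloading real weights
        # all container/TGI/TEI options are now Optional and only forwarded when set,
        # e.g. max_concurrent_requests=128 or dtype="float16"
    )
    benchmark_config = BenchmarkConfig(
        name="cpu_py_txi_gpt2",
        launcher=ProcessConfig(),
        scenario=InferenceConfig(latency=True),
        backend=backend_config,
    )
    benchmark_report = Benchmark.launch(benchmark_config)

    # convert artifacts to a dictionary or dataframe, as in the README example
    benchmark_config.to_dict()  # or benchmark_config.to_dataframe()
```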