diff --git a/examples/cuda_trt_llama.yaml b/examples/cuda_trt_llama.yaml
index c483fc2f..26f35b2c 100644
--- a/examples/cuda_trt_llama.yaml
+++ b/examples/cuda_trt_llama.yaml
@@ -15,10 +15,11 @@ launcher:
 backend:
   device: cuda
   device_ids: 0
-  max_batch_size: 4
-  max_new_tokens: 32
-  max_prompt_length: 64
+  force_export: true
   model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
+  max_prompt_length: 64
+  max_new_tokens: 32
+  max_batch_size: 4
 
 scenario:
   input_shapes:
diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py
index a05187c3..f46ce6c8 100644
--- a/optimum_benchmark/backends/tensorrt_llm/backend.py
+++ b/optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -46,6 +46,7 @@ def load_trtmodel_from_pretrained(self) -> None:
             max_batch_size=self.config.max_batch_size,
             max_new_tokens=self.config.max_new_tokens,
             max_beam_width=self.config.max_beam_width,
+            force_export=self.config.force_export,
             **self.config.model_kwargs,
         )
diff --git a/optimum_benchmark/backends/tensorrt_llm/config.py b/optimum_benchmark/backends/tensorrt_llm/config.py
index d7f4b1cb..4fc83f11 100644
--- a/optimum_benchmark/backends/tensorrt_llm/config.py
+++ b/optimum_benchmark/backends/tensorrt_llm/config.py
@@ -18,6 +18,7 @@ class TRTLLMConfig(BackendConfig):
     pp: int = 1
     use_fp8: bool = False
     dtype: str = "float16"
+    force_export: bool = False
     optimization_level: int = 2
     use_cuda_graph: bool = False
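
For reference, a minimal sketch of how the new `force_export` flag could be exercised from Python rather than through the Hydra CLI. Only `TRTLLMConfig.force_export` comes from this patch; the `Benchmark`, `BenchmarkConfig`, `InferenceConfig`, and `ProcessConfig` entry points are assumed from optimum-benchmark's documented Python API and may vary between versions, and the comment on the flag's behavior is an assumption, not confirmed by the diff.

```python
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig
from optimum_benchmark.backends.tensorrt_llm.config import TRTLLMConfig

# Mirrors examples/cuda_trt_llama.yaml, with the new flag switched on.
backend_config = TRTLLMConfig(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device="cuda",
    device_ids="0",
    force_export=True,  # assumption: build a fresh TensorRT-LLM engine instead of reusing a cached/prebuilt one
    max_prompt_length=64,
    max_new_tokens=32,
    max_batch_size=4,
)

benchmark_config = BenchmarkConfig(
    name="cuda_trt_llama",
    launcher=ProcessConfig(),
    scenario=InferenceConfig(latency=True),
    backend=backend_config,
)

if __name__ == "__main__":
    benchmark_report = Benchmark.launch(benchmark_config)
```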