From 1520bd25d2ee64a62aef4b57e1b6be8e1f24920f Mon Sep 17 00:00:00 2001 From: braf Date: Mon, 14 Aug 2023 15:28:50 +0000 Subject: [PATCH] Fixes for C_API mode --- .../config/generate/base_model_config_generator.py | 10 ++++++++-- .../config/generate/quick_run_config_generator.py | 7 +++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/model_analyzer/config/generate/base_model_config_generator.py b/model_analyzer/config/generate/base_model_config_generator.py index ada906520..f456355e7 100755 --- a/model_analyzer/config/generate/base_model_config_generator.py +++ b/model_analyzer/config/generate/base_model_config_generator.py @@ -70,6 +70,7 @@ def __init__( self._base_model = model self._base_model_name = model.model_name() self._remote_mode = config.triton_launch_mode == "remote" + self._c_api_mode = config.triton_launch_mode == "c_api" self._cpu_only = model.cpu_only() self._default_only = default_only self._early_exit_enable = early_exit_enable @@ -171,6 +172,7 @@ def _make_direct_mode_model_config_variant( param_combo=param_combo, model=self._base_model, model_variant_name_manager=self._model_variant_name_manager, + c_api_mode=self._c_api_mode, ) @staticmethod @@ -178,6 +180,7 @@ def make_model_config_variant( param_combo: dict, model: ModelProfileSpec, model_variant_name_manager: ModelVariantNameManager, + c_api_mode: bool, ) -> ModelConfigVariant: """ Loads the base model config from the model repository, and then applies the @@ -189,6 +192,7 @@ def make_model_config_variant( dict of key:value pairs to apply to the model config model: ModelProfileSpec model_variant_name_manager: ModelVariantNameManager + c_api_mode: Set to true if mode is c_api """ logger_str: List[str] = [] model_name = model.model_name() @@ -211,7 +215,7 @@ def make_model_config_variant( logger.info(str) logger.info("") - model_config_dict["name"] = model_name + model_config_dict["name"] = variant_name if c_api_mode else model_name model_config = ModelConfig.create_from_dictionary(model_config_dict) model_config.set_cpu_only(model.cpu_only()) @@ -222,6 +226,7 @@ def make_ensemble_model_config_variant( model: ModelProfileSpec, ensemble_composing_model_config_variants: List[ModelConfigVariant], model_variant_name_manager: ModelVariantNameManager, + c_api_mode: bool, param_combo: Dict = {}, ) -> ModelConfigVariant: """ @@ -235,6 +240,7 @@ def make_ensemble_model_config_variant( ensemble_composing_model_config_variants: List of ModelConfigVariants The list of composing model ModelConfigs model_variant_name_manager: ModelVariantNameManager + c_api_mode: Set to true if mode is c_api """ logger_str: List[str] = [] @@ -261,7 +267,7 @@ def make_ensemble_model_config_variant( for str in logger_str: logger.info(str) - model_config_dict["name"] = model_name + model_config_dict["name"] = variant_name if c_api_mode else model_name model_config = ModelConfig.create_from_dictionary(model_config_dict) return ModelConfigVariant(model_config, variant_name) diff --git a/model_analyzer/config/generate/quick_run_config_generator.py b/model_analyzer/config/generate/quick_run_config_generator.py index 2d8cf6a3c..5454765bf 100755 --- a/model_analyzer/config/generate/quick_run_config_generator.py +++ b/model_analyzer/config/generate/quick_run_config_generator.py @@ -91,6 +91,8 @@ def __init__( self._triton_env = BruteRunConfigGenerator.determine_triton_server_env(models) + self._c_api_mode = config.triton_launch_mode == "c_api" + # This tracks measured results for all coordinates self._coordinate_data = CoordinateData() @@ -425,6 +427,7 @@ def _get_next_ensemble_model_config_variant( ensemble_composing_model_config_variants=composing_config_variants, model_variant_name_manager=self._model_variant_name_manager, param_combo=param_combo, + c_api_mode=self._c_api_mode, ) ) @@ -471,6 +474,7 @@ def _get_next_model_config_variant( param_combo=param_combo, model=model, model_variant_name_manager=self._model_variant_name_manager, + c_api_mode=self._c_api_mode, ) return model_config_variant @@ -624,6 +628,7 @@ def _create_default_ensemble_model_run_config( model=model, ensemble_composing_model_config_variants=default_composing_model_config_variants, model_variant_name_manager=self._model_variant_name_manager, + c_api_mode=self._c_api_mode, ) default_perf_analyzer_config = self._create_default_perf_analyzer_config( @@ -652,6 +657,7 @@ def _create_default_composing_model_config_variants( param_combo={}, model=composing_model, model_variant_name_manager=self._model_variant_name_manager, + c_api_mode=self._c_api_mode, ) ) @@ -665,6 +671,7 @@ def _create_default_model_run_config( param_combo={}, model=model, model_variant_name_manager=self._model_variant_name_manager, + c_api_mode=self._c_api_mode, ) )