Skip to content

Commit

Permalink
Fixes for C_API mode
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Aug 14, 2023
1 parent 19daeda commit 1520bd2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
10 changes: 8 additions & 2 deletions model_analyzer/config/generate/base_model_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
self._base_model = model
self._base_model_name = model.model_name()
self._remote_mode = config.triton_launch_mode == "remote"
self._c_api_mode = config.triton_launch_mode == "c_api"
self._cpu_only = model.cpu_only()
self._default_only = default_only
self._early_exit_enable = early_exit_enable
Expand Down Expand Up @@ -171,13 +172,15 @@ def _make_direct_mode_model_config_variant(
param_combo=param_combo,
model=self._base_model,
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)

@staticmethod
def make_model_config_variant(
param_combo: dict,
model: ModelProfileSpec,
model_variant_name_manager: ModelVariantNameManager,
c_api_mode: bool,
) -> ModelConfigVariant:
"""
Loads the base model config from the model repository, and then applies the
Expand All @@ -189,6 +192,7 @@ def make_model_config_variant(
dict of key:value pairs to apply to the model config
model: ModelProfileSpec
model_variant_name_manager: ModelVariantNameManager
c_api_mode: Set to true if mode is c_api
"""
logger_str: List[str] = []
model_name = model.model_name()
Expand All @@ -211,7 +215,7 @@ def make_model_config_variant(
logger.info(str)
logger.info("")

model_config_dict["name"] = model_name
model_config_dict["name"] = variant_name if c_api_mode else model_name
model_config = ModelConfig.create_from_dictionary(model_config_dict)
model_config.set_cpu_only(model.cpu_only())

Expand All @@ -222,6 +226,7 @@ def make_ensemble_model_config_variant(
model: ModelProfileSpec,
ensemble_composing_model_config_variants: List[ModelConfigVariant],
model_variant_name_manager: ModelVariantNameManager,
c_api_mode: bool,
param_combo: Dict = {},
) -> ModelConfigVariant:
"""
Expand All @@ -235,6 +240,7 @@ def make_ensemble_model_config_variant(
ensemble_composing_model_config_variants: List of ModelConfigVariants
The list of composing model ModelConfigs
model_variant_name_manager: ModelVariantNameManager
c_api_mode: Set to true if mode is c_api
"""
logger_str: List[str] = []
Expand All @@ -261,7 +267,7 @@ def make_ensemble_model_config_variant(
for str in logger_str:
logger.info(str)

model_config_dict["name"] = model_name
model_config_dict["name"] = variant_name if c_api_mode else model_name
model_config = ModelConfig.create_from_dictionary(model_config_dict)

return ModelConfigVariant(model_config, variant_name)
Expand Down
7 changes: 7 additions & 0 deletions model_analyzer/config/generate/quick_run_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def __init__(

self._triton_env = BruteRunConfigGenerator.determine_triton_server_env(models)

self._c_api_mode = config.triton_launch_mode == "c_api"

# This tracks measured results for all coordinates
self._coordinate_data = CoordinateData()

Expand Down Expand Up @@ -425,6 +427,7 @@ def _get_next_ensemble_model_config_variant(
ensemble_composing_model_config_variants=composing_config_variants,
model_variant_name_manager=self._model_variant_name_manager,
param_combo=param_combo,
c_api_mode=self._c_api_mode,
)
)

Expand Down Expand Up @@ -471,6 +474,7 @@ def _get_next_model_config_variant(
param_combo=param_combo,
model=model,
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)

return model_config_variant
Expand Down Expand Up @@ -624,6 +628,7 @@ def _create_default_ensemble_model_run_config(
model=model,
ensemble_composing_model_config_variants=default_composing_model_config_variants,
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)

default_perf_analyzer_config = self._create_default_perf_analyzer_config(
Expand Down Expand Up @@ -652,6 +657,7 @@ def _create_default_composing_model_config_variants(
param_combo={},
model=composing_model,
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)
)

Expand All @@ -665,6 +671,7 @@ def _create_default_model_run_config(
param_combo={},
model=model,
model_variant_name_manager=self._model_variant_name_manager,
c_api_mode=self._c_api_mode,
)
)

Expand Down

0 comments on commit 1520bd2

Please sign in to comment.