diff --git a/model_analyzer/config/generate/base_model_config_generator.py b/model_analyzer/config/generate/base_model_config_generator.py index a99aa3310..475295eb0 100755 --- a/model_analyzer/config/generate/base_model_config_generator.py +++ b/model_analyzer/config/generate/base_model_config_generator.py @@ -203,8 +203,6 @@ def make_model_config_variant( model_name, model_config_dict, param_combo ) - model_config_dict["name"] = variant_name - logger.info("") if variant_found: logger.info(f"Found existing model config: {variant_name}") else: @@ -213,6 +211,7 @@ def make_model_config_variant( logger.info(str) logger.info("") + model_config_dict["name"] = variant_name model_config = ModelConfig.create_from_dictionary(model_config_dict) model_config.set_cpu_only(model.cpu_only()) @@ -226,7 +225,8 @@ def make_ensemble_model_config_variant( param_combo: Dict = {}, ) -> ModelConfigVariant: """ - Loads the ensemble model spec from the model repository + Loads the ensemble model spec from the model repository, and then mutates + the names to match the ensemble composing models Parameters ---------- @@ -258,8 +258,6 @@ def make_ensemble_model_config_variant( model_name, ensemble_key ) - model_config_dict["name"] = variant_name - logger.info("") if variant_found: logger.info(f"Found existing ensemble model config: {variant_name}") else: @@ -267,6 +265,7 @@ def make_ensemble_model_config_variant( for str in logger_str: logger.info(str) + model_config_dict["name"] = variant_name model_config = ModelConfig.create_from_dictionary(model_config_dict) return ModelConfigVariant(model_config, variant_name) diff --git a/model_analyzer/config/run/model_run_config.py b/model_analyzer/config/run/model_run_config.py index eb8e2c7a9..49fd3f3c8 100755 --- a/model_analyzer/config/run/model_run_config.py +++ b/model_analyzer/config/run/model_run_config.py @@ -296,7 +296,7 @@ def from_dict(cls, model_run_config_dict): model_run_config_dict["_perf_config"] ) - # FIXME: This is for backward compatibility with older checkpoints used in unit tests + # TODO: TMA-1332: This is for backward compatibility with older checkpoints used in unit tests if "_model_config" in model_run_config_dict: model_config = ModelConfig.from_dict(model_run_config_dict["_model_config"]) model_run_config._model_config_variant = ModelConfigVariant( @@ -311,7 +311,7 @@ def from_dict(cls, model_run_config_dict): ] ] - # FIXME: This is for backward compatibility with older checkpoints used in unit tests + # TODO: TMA-1332: This is for backward compatibility with older checkpoints used in unit tests if "_composing_configs" in model_run_config_dict: composing_configs = [ ModelConfig.from_dict(composing_config_dict) diff --git a/model_analyzer/record/metrics_manager.py b/model_analyzer/record/metrics_manager.py index 0d777b97b..f8c6e0891 100755 --- a/model_analyzer/record/metrics_manager.py +++ b/model_analyzer/record/metrics_manager.py @@ -390,7 +390,7 @@ def _load_model_variants(self, run_config): for composing_config_variant in mrc.composing_configs(): original_composing_config = ( BaseModelConfigGenerator.create_original_config_from_variant( - composing_config + composing_config_variant ) ) if not self._load_model_variant( diff --git a/model_analyzer/triton/model/model_config_variant.py b/model_analyzer/triton/model/model_config_variant.py index a29ba6030..13716d6b6 100644 --- a/model_analyzer/triton/model/model_config_variant.py +++ b/model_analyzer/triton/model/model_config_variant.py @@ -13,7 +13,6 @@ # limitations under the License. from dataclasses import dataclass -from typing import List from model_analyzer.triton.model.model_config import ModelConfig