diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py
index 1c039163..8488b457 100644
--- a/optimum_benchmark/backends/base.py
+++ b/optimum_benchmark/backends/base.py
@@ -13,14 +13,14 @@
 from .config import BackendConfigT
 from .diffusers_utils import (
     extract_diffusers_shapes_from_model,
-    get_diffusers_automodel_loader_for_task,
+    get_diffusers_auto_pipeline_class_for_task,
     get_diffusers_pretrained_config,
 )
-from .timm_utils import extract_timm_shapes_from_config, get_timm_automodel_loader, get_timm_pretrained_config
+from .timm_utils import extract_timm_shapes_from_config, get_timm_model_creator, get_timm_pretrained_config
 from .transformers_utils import (
     PretrainedProcessor,
     extract_transformers_shapes_from_artifacts,
-    get_transformers_automodel_loader_for_task,
+    get_transformers_auto_model_class_for_task,
     get_transformers_generation_config,
     get_transformers_pretrained_config,
     get_transformers_pretrained_processor,
@@ -56,7 +56,7 @@ def __init__(self, config: BackendConfigT):
             self.logger.info("\t+ Benchmarking a Diffusers pipeline")
             self.pretrained_config = get_diffusers_pretrained_config(self.config.model, **self.config.model_kwargs)
             self.model_shapes = extract_diffusers_shapes_from_model(self.config.model, **self.config.model_kwargs)
-            self.automodel_loader = get_diffusers_automodel_loader_for_task(self.config.task)
+            self.automodel_loader = get_diffusers_auto_pipeline_class_for_task(self.config.task)
             self.pretrained_processor = None
             self.generation_config = None

@@ -64,7 +64,7 @@ def __init__(self, config: BackendConfigT):
             self.logger.info("\t+ Benchmarking a Timm model")
             self.pretrained_config = get_timm_pretrained_config(self.config.model)
             self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config)
-            self.automodel_loader = get_timm_automodel_loader()
+            self.automodel_loader = get_timm_model_creator()
             self.pretrained_processor = None
             self.generation_config = None

@@ -78,7 +78,7 @@ def __init__(self, config: BackendConfigT):

         else:
             self.logger.info("\t+ Benchmarking a Transformers model")
-            self.automodel_loader = get_transformers_automodel_loader_for_task(self.config.task, self.config.model_type)
+            self.automodel_loader = get_transformers_auto_model_class_for_task(self.config.task, self.config.model_type)
             self.generation_config = get_transformers_generation_config(self.config.model, **self.config.model_kwargs)
             self.pretrained_config = get_transformers_pretrained_config(self.config.model, **self.config.model_kwargs)
             self.pretrained_processor = get_transformers_pretrained_processor(
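The three renamed helpers all hand the backend a single `automodel_loader` attribute: an auto pipeline class for Diffusers, timm's `create_model` factory, and a Transformers auto model class otherwise. A minimal sketch of that per-library dispatch (illustrative only, not part of the diff; the wrapper function name is hypothetical):

```python
# Hypothetical helper sketching how the renamed getters above are selected per library.
from typing import Optional


def resolve_automodel_loader(library: str, task: str, model_type: Optional[str] = None):
    if library == "diffusers":
        from optimum_benchmark.backends.diffusers_utils import get_diffusers_auto_pipeline_class_for_task

        return get_diffusers_auto_pipeline_class_for_task(task)  # e.g. AutoPipelineForText2Image
    elif library == "timm":
        from optimum_benchmark.backends.timm_utils import get_timm_model_creator

        return get_timm_model_creator()  # timm.create_model
    else:
        from optimum_benchmark.backends.transformers_utils import get_transformers_auto_model_class_for_task

        return get_transformers_auto_model_class_for_task(task, model_type)  # e.g. AutoModelForCausalLM
```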
diff --git a/optimum_benchmark/backends/diffusers_utils.py b/optimum_benchmark/backends/diffusers_utils.py
index 43f0757b..ef1b4a59 100644
--- a/optimum_benchmark/backends/diffusers_utils.py
+++ b/optimum_benchmark/backends/diffusers_utils.py
@@ -9,33 +9,16 @@
     import diffusers
     from diffusers import DiffusionPipeline

-    if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
-        from diffusers.pipelines.auto_pipeline import (
-            AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
-            AUTO_INPAINT_PIPELINES_MAPPING,
-            AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
-        )
-
-        TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {
-            "inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
-            "text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
-            "image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
-        }
-
-        for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
-            for model_type, model_class in model_mapping.items():
-                TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name][model_type] = model_class.__name__
-    else:
-        TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
-else:
-    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
+def get_diffusers_auto_pipeline_class_for_task(task: str):
+    from ..task_utils import TASKS_TO_AUTO_PIPELINE_CLASS_NAMES

-TASKS_TO_MODEL_LOADERS = {
-    "inpainting": "AutoPipelineForInpainting",
-    "text-to-image": "AutoPipelineForText2Image",
-    "image-to-image": "AutoPipelineForImage2Image",
-}
+    if not is_diffusers_available():
+        raise ImportError("diffusers is not available. Please, pip install diffusers.")
+
+    model_loader_name = TASKS_TO_AUTO_PIPELINE_CLASS_NAMES.get(task, None)
+    model_loader_class = getattr(diffusers, model_loader_name)
+    return model_loader_class


 def get_diffusers_pretrained_config(model: str, **kwargs) -> Dict[str, int]:
@@ -85,12 +68,3 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
         shapes["width"] = -1

     return shapes
-
-
-def get_diffusers_automodel_loader_for_task(task: str):
-    if not is_diffusers_available():
-        raise ImportError("diffusers is not available. Please, pip install diffusers.")
-
-    model_loader_name = TASKS_TO_MODEL_LOADERS[task]
-    model_loader_class = getattr(diffusers, model_loader_name)
-    return model_loader_class
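The relocated helper now reads its mapping from `task_utils` and resolves the auto pipeline class lazily. A hedged usage sketch (illustrative, not part of the diff; assumes diffusers is installed). Note that for a task outside the three supported ones, `.get(task, None)` returns `None` and the subsequent `getattr` raises a `TypeError` rather than a descriptive error:

```python
# Illustrative usage of the new helper; assumes diffusers is installed.
from optimum_benchmark.backends.diffusers_utils import get_diffusers_auto_pipeline_class_for_task

pipeline_class = get_diffusers_auto_pipeline_class_for_task("text-to-image")
print(pipeline_class.__name__)  # "AutoPipelineForText2Image"

# "inpainting" -> AutoPipelineForInpainting, "image-to-image" -> AutoPipelineForImage2Image;
# any other task falls through to getattr(diffusers, None) and raises a TypeError.
```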
diff --git a/optimum_benchmark/backends/timm_utils.py b/optimum_benchmark/backends/timm_utils.py
index dbaf36fd..4cb3cd1c 100644
--- a/optimum_benchmark/backends/timm_utils.py
+++ b/optimum_benchmark/backends/timm_utils.py
@@ -10,6 +10,13 @@
     from timm.models import get_pretrained_cfg, load_model_config_from_hf, parse_model_name


+def get_timm_model_creator():
+    if not is_timm_available():
+        raise ImportError("timm is not available. Please, pip install timm.")
+
+    return create_model
+
+
 def get_timm_pretrained_config(model_name: str) -> PretrainedConfig:
     if not is_timm_available():
         raise ImportError("timm is not available. Please, pip install timm.")
@@ -71,10 +78,3 @@ def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]:
         warnings.warn("Could not extract shapes [num_channels, height, width] from timm model config.")

     return shapes
-
-
-def get_timm_automodel_loader():
-    if not is_timm_available():
-        raise ImportError("timm is not available. Please, pip install timm.")
-
-    return create_model
diff --git a/optimum_benchmark/backends/transformers_utils.py b/optimum_benchmark/backends/transformers_utils.py
index efd2b8af..7226dd7c 100644
--- a/optimum_benchmark/backends/transformers_utils.py
+++ b/optimum_benchmark/backends/transformers_utils.py
@@ -18,59 +18,23 @@
     SpecialTokensMixin,
 )

-TASKS_TO_AUTOMODEL_CLASS_NAMES = {
-    # text processing
-    "feature-extraction": "AutoModel",
-    "fill-mask": "AutoModelForMaskedLM",
-    "multiple-choice": "AutoModelForMultipleChoice",
-    "question-answering": "AutoModelForQuestionAnswering",
-    "token-classification": "AutoModelForTokenClassification",
-    "text-classification": "AutoModelForSequenceClassification",
-    # audio processing
-    "audio-xvector": "AutoModelForAudioXVector",
-    "text-to-audio": "AutoModelForTextToSpectrogram",
-    "audio-classification": "AutoModelForAudioClassification",
-    "audio-frame-classification": "AutoModelForAudioFrameClassification",
-    # image processing
-    "mask-generation": "AutoModel",
-    "image-to-image": "AutoModelForImageToImage",
-    "masked-im": "AutoModelForMaskedImageModeling",
-    "object-detection": "AutoModelForObjectDetection",
-    "depth-estimation": "AutoModelForDepthEstimation",
-    "image-segmentation": "AutoModelForImageSegmentation",
-    "image-classification": "AutoModelForImageClassification",
-    "semantic-segmentation": "AutoModelForSemanticSegmentation",
-    "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
-    "zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
-    # text generation
-    "image-to-text": "AutoModelForVision2Seq",
-    "text-generation": "AutoModelForCausalLM",
-    "text2text-generation": "AutoModelForSeq2SeqLM",
-    "image-text-to-text": "AutoModelForImageTextToText",
-    "visual-question-answering": "AutoModelForVisualQuestionAnswering",
-    "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"),
-}
-
-SYNONYM_TASKS = {
-    "summarization": "text2text-generation",
-    "sentence-similarity": "feature-extraction",
-}
+def get_transformers_auto_model_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
+    from ..task_utils import SYNONYM_TASKS, TASKS_TO_AUTO_MODEL_CLASS_NAMES

-
-def get_transformers_automodel_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
     if task in SYNONYM_TASKS:
         task = SYNONYM_TASKS[task]

-    if task not in TASKS_TO_AUTOMODEL_CLASS_NAMES:
+    if task not in TASKS_TO_AUTO_MODEL_CLASS_NAMES:
         raise ValueError(f"Task {task} not supported")

-    if isinstance(TASKS_TO_AUTOMODEL_CLASS_NAMES[task], str):
-        return getattr(transformers, TASKS_TO_AUTOMODEL_CLASS_NAMES[task])
+    if isinstance(TASKS_TO_AUTO_MODEL_CLASS_NAMES[task], str):
+        return getattr(transformers, TASKS_TO_AUTO_MODEL_CLASS_NAMES[task])
     else:
         if model_type is None:
             raise ValueError(f"Task {task} requires a model_type to be specified")

-        for automodel_class_name in TASKS_TO_AUTOMODEL_CLASS_NAMES[task]:
+        for automodel_class_name in TASKS_TO_AUTO_MODEL_CLASS_NAMES[task]:
             automodel_class = getattr(transformers, automodel_class_name)
             if model_type in automodel_class._model_mapping._model_mapping:
                 return automodel_class
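Most tasks map to a single auto class name, but "automatic-speech-recognition" is a tuple, so the renamed helper needs a `model_type` to probe each candidate's `_model_mapping`. A hedged usage sketch (illustrative, not part of the diff; the model types shown are assumptions about typical seq2seq vs. CTC architectures):

```python
# Illustrative usage; assumes transformers (with torch) is installed.
from optimum_benchmark.backends.transformers_utils import get_transformers_auto_model_class_for_task

# Single-class tasks resolve directly from the task name.
print(get_transformers_auto_model_class_for_task("text-generation").__name__)  # "AutoModelForCausalLM"

# Tuple-valued tasks are disambiguated through the model_type:
# "whisper" is assumed to live in AutoModelForSpeechSeq2Seq's mapping,
# "wav2vec2" in AutoModelForCTC's mapping.
print(get_transformers_auto_model_class_for_task("automatic-speech-recognition", model_type="whisper").__name__)
print(get_transformers_auto_model_class_for_task("automatic-speech-recognition", model_type="wav2vec2").__name__)
```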
self.logger.info("\t+ Updating Text Generation kwargs with default values") self.config.generate_kwargs = {**TEXT_GENERATION_DEFAULT_KWARGS, **self.config.generate_kwargs} self.logger.info("\t+ Initializing Text Generation report") - self.report = BenchmarkReport.from_list(targets=["load", "prefill", "decode", "per_token"]) + self.report = BenchmarkReport.from_list(targets=["load_model", "prefill", "decode", "per_token"]) elif self.backend.config.task in IMAGE_DIFFUSION_TASKS: self.logger.info("\t+ Updating Image Diffusion kwargs with default values") self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs} self.logger.info("\t+ Initializing Image Diffusion report") - self.report = BenchmarkReport.from_list(targets=["load", "call"]) + self.report = BenchmarkReport.from_list(targets=["load_model", "call"]) else: self.logger.info("\t+ Initializing Inference report") - self.report = BenchmarkReport.from_list(targets=["load", "forward"]) + self.report = BenchmarkReport.from_list(targets=["load_model", "forward"]) + + self.run_model_loading_tracking(backend) self.logger.info("\t+ Creating input generator") self.input_generator = InputGenerator( @@ -83,15 +85,11 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: input_shapes=self.config.input_shapes, model_type=backend.config.model_type, ) - self.logger.info("\t+ Generating inputs") self.inputs = self.input_generator() - - self.logger.info("\t+ Preparing inputs for Inference") + self.logger.info("\t+ Preparing inputs for backend") self.inputs = backend.prepare_inputs(inputs=self.inputs) - self.run_model_loading_tracking(backend) - if self.config.latency or self.config.energy: # latency and energy are metrics that require some warmup if self.config.warmup_runs > 0: @@ -159,8 +157,14 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]): ) if self.config.latency: latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device) + if self.config.energy: + energy_tracker = EnergyTracker( + backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids + ) with ExitStack() as context_stack: + if self.config.energy: + context_stack.enter_context(energy_tracker.track()) if self.config.memory: context_stack.enter_context(memory_tracker.track()) if self.config.latency: @@ -169,9 +173,11 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]): backend.load() if self.config.latency: - self.report.load.latency = latency_tracker.get_latency() + self.report.load_model.latency = latency_tracker.get_latency() if self.config.memory: - self.report.load.memory = memory_tracker.get_max_memory() + self.report.load_model.memory = memory_tracker.get_max_memory() + if self.config.energy: + self.report.load_model.energy = energy_tracker.get_energy() ## Memory tracking def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]): diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py index 0a2a98c2..1821b47d 100644 --- a/optimum_benchmark/task_utils.py +++ b/optimum_benchmark/task_utils.py @@ -5,38 +5,94 @@ import huggingface_hub -from .backends.diffusers_utils import ( - TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES as DIFFUSERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES, -) -from .backends.diffusers_utils import ( - get_diffusers_pretrained_config, -) +from .backends.diffusers_utils import get_diffusers_pretrained_config from .backends.timm_utils import get_timm_pretrained_config -from 
diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py
index 0a2a98c2..1821b47d 100644
--- a/optimum_benchmark/task_utils.py
+++ b/optimum_benchmark/task_utils.py
@@ -5,38 +5,94 @@
 import huggingface_hub

-from .backends.diffusers_utils import (
-    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES as DIFFUSERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES,
-)
-from .backends.diffusers_utils import (
-    get_diffusers_pretrained_config,
-)
+from .backends.diffusers_utils import get_diffusers_pretrained_config
 from .backends.timm_utils import get_timm_pretrained_config
-from .backends.transformers_utils import (
-    TASKS_TO_MODEL_LOADERS,
-    get_transformers_pretrained_config,
-)
-from .backends.transformers_utils import (
-    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES as TRANSFORMERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES,
-)
-
-_SYNONYM_TASK_MAP = {
-    "masked-lm": "fill-mask",
-    "causal-lm": "text-generation",
-    "default": "feature-extraction",
-    "vision2seq-lm": "image-to-text",
-    "text-to-speech": "text-to-audio",
-    "seq2seq-lm": "text2text-generation",
-    "translation": "text2text-generation",
-    "summarization": "text2text-generation",
-    "mask-generation": "feature-extraction",
-    "audio-ctc": "automatic-speech-recognition",
-    "sentence-similarity": "feature-extraction",
-    "speech2seq-lm": "automatic-speech-recognition",
-    "sequence-classification": "text-classification",
-    "zero-shot-classification": "text-classification",
+from .backends.transformers_utils import get_transformers_pretrained_config
+from .import_utils import is_diffusers_available, is_torch_available
+
+TASKS_TO_AUTO_MODEL_CLASS_NAMES = {
+    # text processing
+    "feature-extraction": "AutoModel",
+    "fill-mask": "AutoModelForMaskedLM",
+    "multiple-choice": "AutoModelForMultipleChoice",
+    "question-answering": "AutoModelForQuestionAnswering",
+    "token-classification": "AutoModelForTokenClassification",
+    "text-classification": "AutoModelForSequenceClassification",
+    # audio processing
+    "audio-xvector": "AutoModelForAudioXVector",
+    "text-to-audio": "AutoModelForTextToSpectrogram",
+    "audio-classification": "AutoModelForAudioClassification",
+    "audio-frame-classification": "AutoModelForAudioFrameClassification",
+    # image processing
+    "mask-generation": "AutoModel",
+    "image-to-image": "AutoModelForImageToImage",
+    "masked-im": "AutoModelForMaskedImageModeling",
+    "object-detection": "AutoModelForObjectDetection",
+    "depth-estimation": "AutoModelForDepthEstimation",
+    "image-segmentation": "AutoModelForImageSegmentation",
+    "image-classification": "AutoModelForImageClassification",
+    "semantic-segmentation": "AutoModelForSemanticSegmentation",
+    "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
+    "zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
+    # text generation
+    "image-to-text": "AutoModelForVision2Seq",
+    "text-generation": "AutoModelForCausalLM",
+    "text2text-generation": "AutoModelForSeq2SeqLM",
+    "image-text-to-text": "AutoModelForImageTextToText",
+    "visual-question-answering": "AutoModelForVisualQuestionAnswering",
+    "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"),
+}
+
+TASKS_TO_AUTO_PIPELINE_CLASS_NAMES = {
+    "inpainting": "AutoPipelineForInpainting",
+    "text-to-image": "AutoPipelineForText2Image",
+    "image-to-image": "AutoPipelineForImage2Image",
 }

+TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES = {}
+
+if is_torch_available():
+    import transformers
+
+    for task_name, auto_model_class_names in TASKS_TO_AUTO_MODEL_CLASS_NAMES.items():
+        TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES[task_name] = {}
+
+        if isinstance(auto_model_class_names, str):
+            auto_model_class_names = (auto_model_class_names,)
+
+        for auto_model_class_name in auto_model_class_names:
+            auto_model_class = getattr(transformers, auto_model_class_name, None)
+            if auto_model_class is not None:
+                TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES[task_name].update(
+                    auto_model_class._model_mapping._model_mapping
+                )
+
+
+TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES = {}
+
+if is_diffusers_available():
+    import diffusers
+
+    if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
+        from diffusers.pipelines.auto_pipeline import (
+            AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
+            AUTO_INPAINT_PIPELINES_MAPPING,
+            AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+        )
+
+        TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES = {
+            "inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
+            "text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
+            "image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
+        }
+
+        for task_name, pipeline_mapping in TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES.items():
+            for pipeline_type, pipeline_class in pipeline_mapping.items():
+                TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES[task_name][pipeline_type] = pipeline_class.__name__
+    else:
+        TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES = {}
+
+
 IMAGE_DIFFUSION_TASKS = [
     "inpainting",
     "text-to-image",
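The two derived tables are keyed first by task and then by model type or pipeline type, with class *names* as values, so nothing model-specific has to be imported at lookup time. A hedged lookup sketch (illustrative, not part of the diff; the example entries are assumptions and depend on the installed transformers/diffusers versions):

```python
# Illustrative lookups against the derived tables built above.
from optimum_benchmark.task_utils import (
    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES,
    TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES,
)

# task -> {model_type: model_class_name}, e.g. "llama" -> "LlamaForCausalLM"
print(TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES.get("text-generation", {}).get("llama"))

# task -> {pipeline_type: pipeline_class_name}, e.g. "stable-diffusion" -> "StableDiffusionPipeline"
print(TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES.get("text-to-image", {}).get("stable-diffusion"))
```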
@@ -56,15 +112,34 @@
     "feature-extraction",
 ]

+SYNONYM_TASKS = {
+    "masked-lm": "fill-mask",
+    "causal-lm": "text-generation",
+    "default": "feature-extraction",
+    "vision2seq-lm": "image-to-text",
+    "text-to-speech": "text-to-audio",
+    "seq2seq-lm": "text2text-generation",
+    "translation": "text2text-generation",
+    "summarization": "text2text-generation",
+    "mask-generation": "feature-extraction",
+    "audio-ctc": "automatic-speech-recognition",
+    "sentence-similarity": "feature-extraction",
+    "speech2seq-lm": "automatic-speech-recognition",
+    "sequence-classification": "text-classification",
+    "zero-shot-classification": "text-classification",
+}
+

 def map_from_synonym(task: str) -> str:
-    if task in _SYNONYM_TASK_MAP:
-        task = _SYNONYM_TASK_MAP[task]
+    if task in SYNONYM_TASKS:
+        task = SYNONYM_TASKS[task]

     return task


 def infer_library_from_model_name_or_path(
-    model_name_or_path: str, revision: Optional[str] = None, token: Optional[str] = None
+    model_name_or_path: str,
+    token: Optional[str] = None,
+    revision: Optional[str] = None,
 ) -> str:
     inferred_library_name = None
@@ -77,6 +152,18 @@ def infer_library_from_model_name_or_path(
         repo_files = huggingface_hub.list_repo_files(model_name_or_path, revision=revision, token=token)
         if "model_index.json" in repo_files:
             inferred_library_name = "diffusers"
+        elif "config.json" in repo_files:
+            config_dict = json.loads(
+                huggingface_hub.hf_hub_download(
+                    repo_id=model_name_or_path, filename="config.json", revision=revision, token=token
+                )
+            )
+            if "pretrained_cfg" in config_dict or "architecture" in config_dict:
+                inferred_library_name = "timm"
+            elif "_diffusers_version" in config_dict:
+                inferred_library_name = "diffusers"
+            else:
+                inferred_library_name = "transformers"

         if inferred_library_name is None:
             raise RuntimeError(f"Could not infer library name from repo {model_name_or_path}.")
@@ -89,6 +176,7 @@ def infer_library_from_model_name_or_path(
             inferred_library_name = "diffusers"
         elif "config.json" in local_files:
             config_dict = json.load(open(os.path.join(model_name_or_path, "config.json"), "r"))
+
             if "pretrained_cfg" in config_dict or "architecture" in config_dict:
                 inferred_library_name = "timm"
             elif "_diffusers_version" in config_dict:
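Library inference on the Hub no longer requires a `model_index.json`: when only a `config.json` is present, its keys decide between timm (`pretrained_cfg`/`architecture`), diffusers (`_diffusers_version`), and transformers (the default), mirroring the existing local-directory branch. A condensed sketch of that decision order (illustrative, not part of the diff; `config_dict` is assumed to already be the parsed `config.json`):

```python
# Condensed sketch of the decision order introduced above.
def classify_repo(repo_files, config_dict):
    if "model_index.json" in repo_files:
        return "diffusers"
    if "config.json" in repo_files:
        if "pretrained_cfg" in config_dict or "architecture" in config_dict:
            return "timm"
        if "_diffusers_version" in config_dict:
            return "diffusers"
        return "transformers"
    return None  # the real function raises a RuntimeError when nothing matches


print(classify_repo(["config.json"], {"architecture": "resnet50"}))      # "timm"
print(classify_repo(["config.json"], {"_diffusers_version": "0.30.0"}))  # "diffusers"
print(classify_repo(["config.json"], {"model_type": "bert"}))            # "transformers"
```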
@@ -129,14 +217,36 @@
     elif library_name == "sentence-transformers":
         inferred_task_name = "feature-extraction"

+    elif huggingface_hub.repo_exists(model_name_or_path, token=token):
+        model_info = huggingface_hub.model_info(model_name_or_path, revision=revision, token=token)
+
+        if model_info.pipeline_tag is not None:
+            inferred_task_name = map_from_synonym(model_info.pipeline_tag)
+
+        elif inferred_task_name is None:
+            if model_info.transformers_info is not None and model_info.transformersInfo.pipeline_tag is not None:
+                inferred_task_name = map_from_synonym(model_info.transformersInfo.pipeline_tag)
+            else:
+                target_auto_model = model_info.transformers_info["auto_model"]
+                for task_name, auto_model_class_names in TASKS_TO_AUTO_MODEL_CLASS_NAMES.items():
+                    if isinstance(auto_model_class_names, str):
+                        auto_model_class_names = (auto_model_class_names,)
+
+                    for auto_model_class_name in auto_model_class_names:
+                        if target_auto_model == auto_model_class_name:
+                            inferred_task_name = task_name
+                            break
+                    if inferred_task_name is not None:
+                        break
+
     elif os.path.isdir(model_name_or_path):
         if library_name == "diffusers":
             diffusers_config = get_diffusers_pretrained_config(model_name_or_path, revision=revision, token=token)
-            class_name = diffusers_config["_class_name"]
+            target_class_name = diffusers_config["_class_name"]

-            for task_name, model_mapping in DIFFUSERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
-                for model_type, model_class_name in model_mapping.items():
-                    if class_name == model_class_name:
+            for task_name, pipeline_mapping in TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES.items():
+                for _, pipeline_class_name in pipeline_mapping.items():
+                    if target_class_name == pipeline_class_name:
                         inferred_task_name = task_name
                         break
                 if inferred_task_name is not None:
@@ -147,7 +257,7 @@ def infer_task_from_model_name_or_path(
             auto_modeling_module = importlib.import_module("transformers.models.auto.modeling_auto")
             model_type = transformers_config.model_type

-            for task_name, model_loaders in TRANSFORMERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
+            for task_name, model_loaders in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASS_NAMES.items():
                 if isinstance(model_loaders, str):
                     model_loaders = (model_loaders,)
                 for model_loader in model_loaders:
@@ -159,27 +269,6 @@ def infer_task_from_model_name_or_path(
             if inferred_task_name is not None:
                 break

-    elif huggingface_hub.repo_exists(model_name_or_path, token=token):
-        model_info = huggingface_hub.model_info(model_name_or_path, revision=revision, token=token)
-
-        if model_info.pipeline_tag is not None:
-            inferred_task_name = map_from_synonym(model_info.pipeline_tag)
-
-        elif inferred_task_name is None:
-            if model_info.transformers_info is not None and model_info.transformersInfo.pipeline_tag is not None:
-                inferred_task_name = map_from_synonym(model_info.transformersInfo.pipeline_tag)
-            else:
-                auto_model_class_name = model_info.transformers_info["auto_model"]
-                for task_name, model_loaders in TASKS_TO_MODEL_LOADERS.items():
-                    if isinstance(model_loaders, str):
-                        model_loaders = (model_loaders,)
-                    for model_loader in model_loaders:
-                        if auto_model_class_name == model_loader:
-                            inferred_task_name = task_name
-                            break
-                    if inferred_task_name is not None:
-                        break
-
     if inferred_task_name is None:
         raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.")
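Because the Hub branch now runs before the local-directory checks, a repo's `pipeline_tag` (or its `transformersInfo.auto_model`) is used whenever it is available. A hedged call sketch (illustrative, not part of the diff; assumes the helper only needs the model id, that Hub access is available, and that the repo's pipeline tag is "text-generation"):

```python
# Illustrative call; needs access to the Hugging Face Hub.
from optimum_benchmark.task_utils import infer_task_from_model_name_or_path

print(infer_task_from_model_name_or_path("openai-community/gpt2"))  # expected: "text-generation"
```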
@@ -207,12 +296,12 @@ def infer_model_type_from_model_name_or_path(
     elif library_name == "diffusers":
         config = get_diffusers_pretrained_config(model_name_or_path, revision=revision, token=token)
-        class_name = config["_class_name"]
+        target_class_name = config["_class_name"]

-        for task_name, model_mapping in DIFFUSERS_TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
-            for model_type, model_class_name in model_mapping.items():
-                if model_class_name == class_name:
-                    inferred_model_type = model_type
+        for _, pipeline_mapping in TASKS_TO_PIPELINE_TYPES_TO_PIPELINE_CLASS_NAMES.items():
+            for pipeline_type, pipeline_class_name in pipeline_mapping.items():
+                if target_class_name == pipeline_class_name:
+                    inferred_model_type = pipeline_type
                     break
             if inferred_model_type is not None:
                 break
diff --git a/tests/configs/_st_bert_.yaml b/tests/configs/_st_bert_.yaml
new file mode 100644
index 00000000..05ef4026
--- /dev/null
+++ b/tests/configs/_st_bert_.yaml
@@ -0,0 +1,3 @@
+backend:
+  model: sentence-transformers/all-MiniLM-L6-v2
+  task: feature-extraction
diff --git a/tests/configs/cpu_inference_py_txi_bert.yaml b/tests/configs/cpu_inference_py_txi_st_bert.yaml
similarity index 77%
rename from tests/configs/cpu_inference_py_txi_bert.yaml
rename to tests/configs/cpu_inference_py_txi_st_bert.yaml
index a575be99..2650e1bf 100644
--- a/tests/configs/cpu_inference_py_txi_bert.yaml
+++ b/tests/configs/cpu_inference_py_txi_st_bert.yaml
@@ -3,8 +3,8 @@ defaults:
   - _base_ # inherits from base config
   - _cpu_ # inherits from cpu config
   - _inference_ # inherits from inference config
-  - _bert_ # inherits from bert config
+  - _st_bert_ # inherits from bert config
   - _self_ # hydra 1.1 compatibility
   - override backend: py-txi

-name: cpu_inference_py_txi_bert
+name: cpu_inference_py_txi_st_bert
diff --git a/tests/configs/cuda_inference_py_txi_bert.yaml b/tests/configs/cuda_inference_py_txi_st_bert.yaml
similarity index 77%
rename from tests/configs/cuda_inference_py_txi_bert.yaml
rename to tests/configs/cuda_inference_py_txi_st_bert.yaml
index 62405f30..8ae494e7 100644
--- a/tests/configs/cuda_inference_py_txi_bert.yaml
+++ b/tests/configs/cuda_inference_py_txi_st_bert.yaml
@@ -3,8 +3,8 @@ defaults:
   - _base_ # inherits from base config
   - _cuda_ # inherits from cuda config
   - _inference_ # inherits from inference config
-  - _bert_ # inherits from bert config
+  - _st_bert_ # inherits from bert config
   - _self_ # hydra 1.1 compatibility
   - override backend: py-txi

-name: cuda_inference_py_txi_bert
+name: cuda_inference_py_txi_st_bert
diff --git a/tests/test_api.py b/tests/test_api.py
index fd6e2dac..01851c34 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -22,8 +22,6 @@
 from optimum_benchmark.generators.dataset_generator import DatasetGenerator
 from optimum_benchmark.generators.input_generator import InputGenerator
 from optimum_benchmark.import_utils import get_git_revision_hash
-from optimum_benchmark.scenarios.inference.config import INPUT_SHAPES
-from optimum_benchmark.scenarios.training.config import DATASET_SHAPES
 from optimum_benchmark.system_utils import is_nvidia_system, is_rocm_system
 from optimum_benchmark.trackers import LatencyTracker, MemoryTracker

@@ -40,6 +38,18 @@
     ("diffusers", "text-to-image", "CompVis/stable-diffusion-v1-4"),
 ]

+INPUT_SHAPES = {
+    "batch_size": 2,  # for all tasks
+    "sequence_length": 16,  # for text processing tasks
+    "num_choices": 2,  # for multiple-choice task
+}
+
+DATASET_SHAPES = {
+    "dataset_size": 2,  # for all tasks
+    "sequence_length": 16,  # for text processing tasks
+    "num_choices": 2,  # for multiple-choice task
+}
+

 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 @pytest.mark.parametrize("scenario", ["training", "inference"])
INPUT_SHAPES["num_choices"] = 2 - input_generator = InputGenerator( task=task, input_shapes=INPUT_SHAPES,