Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 27, 2024
1 parent 712d851 commit b157b89
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 152 deletions.
12 changes: 6 additions & 6 deletions optimum_benchmark/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from .config import BackendConfigT
from .diffusers_utils import (
extract_diffusers_shapes_from_model,
get_diffusers_automodel_loader_for_task,
get_diffusers_auto_pipeline_class_for_task,
get_diffusers_pretrained_config,
)
from .timm_utils import extract_timm_shapes_from_config, get_timm_automodel_loader, get_timm_pretrained_config
from .timm_utils import extract_timm_shapes_from_config, get_timm_model_creator, get_timm_pretrained_config
from .transformers_utils import (
PretrainedProcessor,
extract_transformers_shapes_from_artifacts,
get_transformers_automodel_loader_for_task,
get_transformers_auto_model_class_for_task,
get_transformers_generation_config,
get_transformers_pretrained_config,
get_transformers_pretrained_processor,
Expand Down Expand Up @@ -56,15 +56,15 @@ def __init__(self, config: BackendConfigT):
self.logger.info("\t+ Benchmarking a Diffusers pipeline")
self.pretrained_config = get_diffusers_pretrained_config(self.config.model, **self.config.model_kwargs)
self.model_shapes = extract_diffusers_shapes_from_model(self.config.model, **self.config.model_kwargs)
self.automodel_loader = get_diffusers_automodel_loader_for_task(self.config.task)
self.automodel_loader = get_diffusers_auto_pipeline_class_for_task(self.config.task)
self.pretrained_processor = None
self.generation_config = None

elif self.config.library == "timm":
self.logger.info("\t+ Benchmarking a Timm model")
self.pretrained_config = get_timm_pretrained_config(self.config.model)
self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config)
self.automodel_loader = get_timm_automodel_loader()
self.automodel_loader = get_timm_model_creator()
self.pretrained_processor = None
self.generation_config = None

Expand All @@ -78,7 +78,7 @@ def __init__(self, config: BackendConfigT):

else:
self.logger.info("\t+ Benchmarking a Transformers model")
self.automodel_loader = get_transformers_automodel_loader_for_task(self.config.task, self.config.model_type)
self.automodel_loader = get_transformers_auto_model_class_for_task(self.config.task, self.config.model_type)
self.generation_config = get_transformers_generation_config(self.config.model, **self.config.model_kwargs)
self.pretrained_config = get_transformers_pretrained_config(self.config.model, **self.config.model_kwargs)
self.pretrained_processor = get_transformers_pretrained_processor(
Expand Down
42 changes: 8 additions & 34 deletions optimum_benchmark/backends/diffusers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,16 @@
import diffusers
from diffusers import DiffusionPipeline

if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)

TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {
"inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
"text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
"image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
}

for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
for model_type, model_class in model_mapping.items():
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name][model_type] = model_class.__name__
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}

def get_diffusers_auto_pipeline_class_for_task(task: str):
from ..task_utils import TASKS_TO_AUTO_PIPELINE_CLASS_NAMES

TASKS_TO_MODEL_LOADERS = {
"inpainting": "AutoPipelineForInpainting",
"text-to-image": "AutoPipelineForText2Image",
"image-to-image": "AutoPipelineForImage2Image",
}
if not is_diffusers_available():
raise ImportError("diffusers is not available. Please, pip install diffusers.")

model_loader_name = TASKS_TO_AUTO_PIPELINE_CLASS_NAMES.get(task, None)
model_loader_class = getattr(diffusers, model_loader_name)
return model_loader_class


def get_diffusers_pretrained_config(model: str, **kwargs) -> Dict[str, int]:
Expand Down Expand Up @@ -85,12 +68,3 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
shapes["width"] = -1

return shapes


def get_diffusers_automodel_loader_for_task(task: str):
if not is_diffusers_available():
raise ImportError("diffusers is not available. Please, pip install diffusers.")

model_loader_name = TASKS_TO_MODEL_LOADERS[task]
model_loader_class = getattr(diffusers, model_loader_name)
return model_loader_class
14 changes: 7 additions & 7 deletions optimum_benchmark/backends/timm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from timm.models import get_pretrained_cfg, load_model_config_from_hf, parse_model_name


def get_timm_model_creator():
if not is_timm_available():
raise ImportError("timm is not available. Please, pip install timm.")

return create_model


def get_timm_pretrained_config(model_name: str) -> PretrainedConfig:
if not is_timm_available():
raise ImportError("timm is not available. Please, pip install timm.")
Expand Down Expand Up @@ -71,10 +78,3 @@ def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]:
warnings.warn("Could not extract shapes [num_channels, height, width] from timm model config.")

return shapes


def get_timm_automodel_loader():
if not is_timm_available():
raise ImportError("timm is not available. Please, pip install timm.")

return create_model
48 changes: 6 additions & 42 deletions optimum_benchmark/backends/transformers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,59 +18,23 @@
SpecialTokensMixin,
)

TASKS_TO_AUTOMODEL_CLASS_NAMES = {
# text processing
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"multiple-choice": "AutoModelForMultipleChoice",
"question-answering": "AutoModelForQuestionAnswering",
"token-classification": "AutoModelForTokenClassification",
"text-classification": "AutoModelForSequenceClassification",
# audio processing
"audio-xvector": "AutoModelForAudioXVector",
"text-to-audio": "AutoModelForTextToSpectrogram",
"audio-classification": "AutoModelForAudioClassification",
"audio-frame-classification": "AutoModelForAudioFrameClassification",
# image processing
"mask-generation": "AutoModel",
"image-to-image": "AutoModelForImageToImage",
"masked-im": "AutoModelForMaskedImageModeling",
"object-detection": "AutoModelForObjectDetection",
"depth-estimation": "AutoModelForDepthEstimation",
"image-segmentation": "AutoModelForImageSegmentation",
"image-classification": "AutoModelForImageClassification",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
"zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
# text generation
"image-to-text": "AutoModelForVision2Seq",
"text-generation": "AutoModelForCausalLM",
"text2text-generation": "AutoModelForSeq2SeqLM",
"image-text-to-text": "AutoModelForImageTextToText",
"visual-question-answering": "AutoModelForVisualQuestionAnswering",
"automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"),
}

SYNONYM_TASKS = {
"summarization": "text2text-generation",
"sentence-similarity": "feature-extraction",
}

def get_transformers_auto_model_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
from ..task_utils import SYNONYM_TASKS, TASKS_TO_AUTO_MODEL_CLASS_NAMES

def get_transformers_automodel_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
if task in SYNONYM_TASKS:
task = SYNONYM_TASKS[task]

if task not in TASKS_TO_AUTOMODEL_CLASS_NAMES:
if task not in TASKS_TO_AUTO_MODEL_CLASS_NAMES:
raise ValueError(f"Task {task} not supported")

if isinstance(TASKS_TO_AUTOMODEL_CLASS_NAMES[task], str):
return getattr(transformers, TASKS_TO_AUTOMODEL_CLASS_NAMES[task])
if isinstance(TASKS_TO_AUTO_MODEL_CLASS_NAMES[task], str):
return getattr(transformers, TASKS_TO_AUTO_MODEL_CLASS_NAMES[task])
else:
if model_type is None:
raise ValueError(f"Task {task} requires a model_type to be specified")

for automodel_class_name in TASKS_TO_AUTOMODEL_CLASS_NAMES[task]:
for automodel_class_name in TASKS_TO_AUTO_MODEL_CLASS_NAMES[task]:
automodel_class = getattr(transformers, automodel_class_name)
if model_type in automodel_class._model_mapping._model_mapping:
return automodel_class
Expand Down
Loading

0 comments on commit b157b89

Please sign in to comment.