Commit 7343ce8: style

IlyasMoutawwakil committed Nov 28, 2024
1 parent f85fea8
Showing 4 changed files with 9 additions and 12 deletions.
4 changes: 2 additions & 2 deletions optimum_benchmark/backends/diffusers_utils.py
@@ -12,11 +12,11 @@
 
 
 def get_diffusers_auto_pipeline_class_for_task(task: str):
-    task = map_from_synonym_task(task)
-
     if not is_diffusers_available():
         raise ImportError("diffusers is not available. Please, pip install diffusers.")
 
+    task = map_from_synonym_task(task)
+
     if task not in TASKS_TO_AUTO_PIPELINE_CLASS_NAMES:
         raise ValueError(f"Task {task} not supported for diffusers")
 
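Note: the reordering above puts the is_diffusers_available() guard first, so a missing dependency is reported as an ImportError before any task handling. Availability helpers like this are commonly implemented by probing the import machinery without importing the package; a minimal sketch under that assumption (the helper name is_foo_available and package name foo are illustrative, not from this repository):

import importlib.util

def is_foo_available() -> bool:
    # True if the "foo" package is installed, without actually importing it
    return importlib.util.find_spec("foo") is not None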
3 changes: 2 additions & 1 deletion optimum_benchmark/backends/peft_utils.py
@@ -8,9 +8,10 @@
 from peft import PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_model  # type: ignore
 
 
-def apply_peft(model: PreTrainedModel, peft_type: str, peft_config: Dict[str, Any]) -> PreTrainedModel:
+def apply_peft(model: "PreTrainedModel", peft_type: str, peft_config: Dict[str, Any]) -> "PreTrainedModel":
     if not is_peft_available():
         raise ImportError("peft is not available. Please, pip install peft.")
+
     peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type](**peft_config)
 
     return get_peft_model(model=model, peft_config=peft_config)
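The quoted annotations ("PreTrainedModel" instead of PreTrainedModel) are string annotations: they are never evaluated at runtime, so the name does not have to be importable when the optional dependency is missing. They pair naturally with a typing.TYPE_CHECKING guard so static checkers can still resolve the name; a minimal sketch of that pattern, assuming transformers is treated as an optional dependency:

from typing import TYPE_CHECKING, Any, Dict

if TYPE_CHECKING:
    # evaluated only by type checkers, never at runtime
    from transformers import PreTrainedModel

def apply_peft(model: "PreTrainedModel", peft_type: str, peft_config: Dict[str, Any]) -> "PreTrainedModel":
    ...  # body omitted; see the diff above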
8 changes: 2 additions & 6 deletions optimum_benchmark/backends/timm_utils.py
@@ -1,4 +1,3 @@
-import warnings
 from typing import Any, Dict
 
 from transformers import PretrainedConfig

@@ -17,7 +16,7 @@ def get_timm_model_creator():
     return create_model
 
 
-def get_timm_pretrained_config(model_name: str) -> PretrainedConfig:
+def get_timm_pretrained_config(model_name: str) -> "PretrainedConfig":
     if not is_timm_available():
         raise ImportError("timm is not available. Please, pip install timm.")
 

@@ -31,7 +30,7 @@ def get_timm_pretrained_config(model_name: str) -> PretrainedConfig:
     return get_pretrained_cfg(model_name)
 
 
-def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]:
+def extract_timm_shapes_from_config(config: "PretrainedConfig") -> Dict[str, Any]:
     if not is_timm_available():
         raise ImportError("timm is not available. Please, pip install timm.")
 

@@ -74,7 +73,4 @@ def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]:
         shapes["height"] = input_size[1]
         shapes["width"] = input_size[2]
 
-    if "num_classes" not in artifacts_dict:
-        warnings.warn("Could not extract shapes [num_channels, height, width] from timm model config.")
-
     return shapes
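For context, timm pretrained configs usually carry an input_size tuple of (channels, height, width), which the indexing above unpacks into the shapes dictionary. A standalone sketch of that mapping (the dictionary keys other than "height" and "width" are illustrative):

from typing import Any, Dict

def shapes_from_input_size(artifacts_dict: Dict[str, Any]) -> Dict[str, Any]:
    shapes: Dict[str, Any] = {}
    input_size = artifacts_dict.get("input_size")  # e.g. (3, 224, 224)
    if input_size is not None:
        shapes["num_channels"] = input_size[0]
        shapes["height"] = input_size[1]
        shapes["width"] = input_size[2]
    return shapes

print(shapes_from_input_size({"input_size": (3, 224, 224)}))
# -> {'num_channels': 3, 'height': 224, 'width': 224}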
6 changes: 3 additions & 3 deletions optimum_benchmark/backends/transformers_utils.py
@@ -3,6 +3,7 @@
 
 import torch
 import transformers
+from torch import Tensor
 from transformers import (
     AutoConfig,
     AutoFeatureExtractor,

@@ -84,7 +85,7 @@ def get_flat_dict(d: Dict[str, Any]) -> Dict[str, Any]:
     return flat_dict
 
 
-def get_flat_artifact_dict(artifact: Union[PretrainedConfig, PretrainedProcessor]) -> Dict[str, Any]:
+def get_flat_artifact_dict(artifact: Union["PretrainedConfig", "PretrainedProcessor"]) -> Dict[str, Any]:
     artifact_dict = {}
 
     if isinstance(artifact, ProcessorMixin):

@@ -175,7 +176,6 @@ def extract_transformers_shapes_from_artifacts(
         shapes["num_queries"] = flat_artifacts_dict["num_queries"]
 
     # image-text input
-
     if "patch_size" in flat_artifacts_dict:
         shapes["patch_size"] = flat_artifacts_dict["patch_size"]
     if "in_chans" in flat_artifacts_dict:

@@ -212,7 +212,7 @@
 }
 
 
-def fast_random_tensor(tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+def fast_random_tensor(tensor: "Tensor", *args: Any, **kwargs: Any) -> "Tensor":
     return torch.nn.init.uniform_(tensor)
 
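fast_random_tensor ignores whatever arguments the original initializer would have received and fills the tensor uniformly, a common trick for skipping expensive weight initialization when the values do not matter (e.g. when benchmarking). A hedged sketch of how such a helper might be swapped in; the monkey-patching below is an assumption for illustration, not part of this diff:

import torch

def fast_random_tensor(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    # discard the original init's parameters; uniform noise is enough here
    return torch.nn.init.uniform_(tensor)

# replace slow initializers before instantiating a large model (illustrative)
torch.nn.init.normal_ = fast_random_tensor
torch.nn.init.kaiming_uniform_ = fast_random_tensor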
