Remove DP vs TP distinction and simplify aggregation across processes #299

Merged: 17 commits, Nov 28, 2024
1 change: 1 addition & 0 deletions README.md
@@ -61,6 +61,7 @@ Optimum-Benchmark is continuously and intensively tested on a variety of devices
[![CLI_CUDA_TENSORRT_LLM](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_tensorrt_llm.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_tensorrt_llm.yaml)
[![CLI_CUDA_TORCH_ORT](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_torch_ort.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_torch_ort.yaml)
[![CLI_CUDA_VLLM](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_vllm.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_cuda_vllm.yaml)
[![CLI_ENERGY_STAR](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_energy_star.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_energy_star.yaml)
[![CLI_MISC](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_misc.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_misc.yaml)
[![CLI_ROCM_PYTORCH](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_rocm_pytorch.yaml/badge.svg)](https://github.com/huggingface/optimum-benchmark/actions/workflows/test_cli_rocm_pytorch.yaml)

2 changes: 1 addition & 1 deletion examples/pytorch_llama.py
@@ -33,7 +33,7 @@
"torch_dtype": "bfloat16",
"quantization_scheme": "torchao",
"quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
}
},
}


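For context, the torchao block above is one entry inside a larger dictionary of weight configurations, which is why its closing brace now needs a trailing comma. A minimal sketch of the corrected nesting is shown below; the outer variable name and entry key are assumptions for illustration, not taken from the example file.

# Hypothetical reconstruction of the corrected nesting; only the inner torchao
# block is taken from the diff above.
WEIGHTS_CONFIGS = {
    "torchao-int4wo-128": {
        "torch_dtype": "bfloat16",
        "quantization_scheme": "torchao",
        "quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
    },  # the trailing comma added by this PR, so further entries can follow
}
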
12 changes: 6 additions & 6 deletions optimum_benchmark/backends/base.py
@@ -13,14 +13,14 @@
from .config import BackendConfigT
from .diffusers_utils import (
extract_diffusers_shapes_from_model,
get_diffusers_automodel_loader_for_task,
get_diffusers_auto_pipeline_class_for_task,
get_diffusers_pretrained_config,
)
from .timm_utils import extract_timm_shapes_from_config, get_timm_automodel_loader, get_timm_pretrained_config
from .timm_utils import extract_timm_shapes_from_config, get_timm_model_creator, get_timm_pretrained_config
from .transformers_utils import (
PretrainedProcessor,
extract_transformers_shapes_from_artifacts,
get_transformers_automodel_loader_for_task,
get_transformers_auto_model_class_for_task,
get_transformers_generation_config,
get_transformers_pretrained_config,
get_transformers_pretrained_processor,
@@ -56,15 +56,15 @@ def __init__(self, config: BackendConfigT):
self.logger.info("\t+ Benchmarking a Diffusers pipeline")
self.pretrained_config = get_diffusers_pretrained_config(self.config.model, **self.config.model_kwargs)
self.model_shapes = extract_diffusers_shapes_from_model(self.config.model, **self.config.model_kwargs)
self.automodel_loader = get_diffusers_automodel_loader_for_task(self.config.task)
self.automodel_loader = get_diffusers_auto_pipeline_class_for_task(self.config.task)
self.pretrained_processor = None
self.generation_config = None

elif self.config.library == "timm":
self.logger.info("\t+ Benchmarking a Timm model")
self.pretrained_config = get_timm_pretrained_config(self.config.model)
self.model_shapes = extract_timm_shapes_from_config(self.pretrained_config)
self.automodel_loader = get_timm_automodel_loader()
self.automodel_loader = get_timm_model_creator()
self.pretrained_processor = None
self.generation_config = None

@@ -78,7 +78,7 @@ def __init__(self, config: BackendConfigT):

else:
self.logger.info("\t+ Benchmarking a Transformers model")
self.automodel_loader = get_transformers_automodel_loader_for_task(self.config.task, self.config.model_type)
self.automodel_loader = get_transformers_auto_model_class_for_task(self.config.task, self.config.model_type)
self.generation_config = get_transformers_generation_config(self.config.model, **self.config.model_kwargs)
self.pretrained_config = get_transformers_pretrained_config(self.config.model, **self.config.model_kwargs)
self.pretrained_processor = get_transformers_pretrained_processor(
11 changes: 5 additions & 6 deletions optimum_benchmark/backends/config.py
@@ -54,26 +54,25 @@ def __post_init__(self):
# TODO: add cache_dir, token, etc. to these methods
if self.library is None:
self.library = infer_library_from_model_name_or_path(
self.model,
model_name_or_path=self.model,
token=self.model_kwargs.get("token", None),
revision=self.model_kwargs.get("revision", None),
)

if self.task is None:
self.task = infer_task_from_model_name_or_path(
self.model,
self.library,
model_name_or_path=self.model,
token=self.model_kwargs.get("token", None),
revision=self.model_kwargs.get("revision", None),
library_name=self.library,
)

if self.model_type is None:
self.model_type = infer_model_type_from_model_name_or_path(
self.model,
self.library,
model_name_or_path=self.model,
token=self.model_kwargs.get("token", None),
revision=self.model_kwargs.get("revision", None),
trust_remote_code=self.model_kwargs.get("trust_remote_code", False),
library_name=self.library,
)

if self.device is None:
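
For reference, a minimal sketch of how these inference helpers are now invoked with explicit keyword arguments; the import path mirrors the repository layout and the model id is only an example.

# Illustrative call only, assuming the helpers live in optimum_benchmark.task_utils
# as imported by this config module.
from optimum_benchmark.task_utils import (
    infer_library_from_model_name_or_path,
    infer_task_from_model_name_or_path,
)

library = infer_library_from_model_name_or_path(model_name_or_path="gpt2")
task = infer_task_from_model_name_or_path(model_name_or_path="gpt2", library_name=library)
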
46 changes: 12 additions & 34 deletions optimum_benchmark/backends/diffusers_utils.py
@@ -4,38 +4,25 @@
from hydra.utils import get_class

from ..import_utils import is_diffusers_available
from ..task_utils import TASKS_TO_AUTO_PIPELINE_CLASS_NAMES, map_from_synonym_task

if is_diffusers_available():
import diffusers
from diffusers import DiffusionPipeline

if hasattr(diffusers, "pipelines") and hasattr(diffusers.pipelines, "auto_pipeline"):
from diffusers.pipelines.auto_pipeline import (
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)

TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {
"inpainting": AUTO_INPAINT_PIPELINES_MAPPING.copy(),
"text-to-image": AUTO_TEXT2IMAGE_PIPELINES_MAPPING.copy(),
"image-to-image": AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.copy(),
}

for task_name, model_mapping in TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES.items():
for model_type, model_class in model_mapping.items():
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name][model_type] = model_class.__name__
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
else:
TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}

def get_diffusers_auto_pipeline_class_for_task(task: str):
task = map_from_synonym_task(task)

if not is_diffusers_available():
raise ImportError("diffusers is not available. Please, pip install diffusers.")

if task not in TASKS_TO_AUTO_PIPELINE_CLASS_NAMES:
raise ValueError(f"Task {task} not supported for diffusers")

model_loader_name = TASKS_TO_AUTO_PIPELINE_CLASS_NAMES[task]

TASKS_TO_MODEL_LOADERS = {
"inpainting": "AutoPipelineForInpainting",
"text-to-image": "AutoPipelineForText2Image",
"image-to-image": "AutoPipelineForImage2Image",
}
return getattr(diffusers, model_loader_name)


def get_diffusers_pretrained_config(model: str, **kwargs) -> Dict[str, int]:
@@ -85,12 +72,3 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
shapes["width"] = -1

return shapes


def get_diffusers_automodel_loader_for_task(task: str):
if not is_diffusers_available():
raise ImportError("diffusers is not available. Please, pip install diffusers.")

model_loader_name = TASKS_TO_MODEL_LOADERS[task]
model_loader_class = getattr(diffusers, model_loader_name)
return model_loader_class
21 changes: 2 additions & 19 deletions optimum_benchmark/backends/ipex/backend.py
@@ -84,31 +84,14 @@ def automodel_kwargs(self) -> Dict[str, Any]:
if self.config.torch_dtype is not None:
kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)

print(kwargs)

return kwargs

@property
def is_dp_distributed(self) -> bool:
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

# registering input shapes for usage during model reshaping
self.input_shapes = input_shapes

return input_shapes

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

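For context, a minimal sketch of what Accelerator.split_between_processes does with a batch; the same pattern is reused by the ONNX Runtime, OpenVINO and PyTorch backends below. Each process receives its own slice of the batch dimension, so the batch size no longer has to be divisible by the world size (with apply_padding=False, slices of an uneven batch simply differ in size).

# Sketch assuming a 2-process launch, e.g. torchrun --nproc_per_node=2 sketch.py
import torch
from accelerate import Accelerator

accelerator = Accelerator()
inputs = {"input_ids": torch.randint(0, 100, (8, 16)), "attention_mask": torch.ones(8, 16)}

with accelerator.split_between_processes(inputs, apply_padding=False) as process_inputs:
    # on a 2-process run, each rank sees a (4, 16) slice of the full (8, 16) batch
    print(accelerator.process_index, process_inputs["input_ids"].shape)
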
16 changes: 4 additions & 12 deletions optimum_benchmark/backends/onnxruntime/backend.py
@@ -280,20 +280,12 @@ def quantize_onnx_files(self) -> None:
if self.pretrained_config is not None:
self.pretrained_config.save_pretrained(self.quantized_model)

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

return input_shapes
@property
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

49 changes: 21 additions & 28 deletions optimum_benchmark/backends/openvino/backend.py
@@ -82,7 +82,7 @@ def load(self) -> None:
if self.config.reshape:
static_shapes = {
key: value
for key, value in {**self.input_shapes, **self.model_shapes}.items()
for key, value in self.model_shapes.items()
if key in inspect.getfullargspec(self.pretrained_model.reshape).args
}
if ("sequence_length" in static_shapes) and ("height" in static_shapes) and ("width" in static_shapes):
@@ -135,20 +135,6 @@ def _load_ovmodel_with_no_weights(self) -> None:
self.config.export = original_export
self.config.model = original_model

@property
def is_dp_distributed(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

@property
def ovmodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

if self.config.task in TEXT_GENERATION_TASKS:
kwargs["use_cache"] = self.config.use_cache
kwargs["use_merged"] = self.config.use_merged

return kwargs

def quantize_automodel(self) -> None:
self.logger.info("\t+ Attempting quantization")
self.quantized_model = f"{self.tmpdir.name}/quantized_model"
@@ -181,30 +167,37 @@ def quantize_automodel(self) -> None:
batch_size=1,
)

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()
@property
def ovmodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

# registering input shapes for usage during model reshaping
self.input_shapes = input_shapes
if self.config.task in TEXT_GENERATION_TASKS:
kwargs["use_cache"] = self.config.use_cache
kwargs["use_merged"] = self.config.use_merged

return input_shapes
return kwargs

@property
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

for key in list(inputs.keys()):
if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
inputs.pop(key)

if "input_ids" in inputs:
self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))

if "pixel_values" in inputs:
self.model_shapes.update(
dict(zip(["batch_size", "num_channels", "height", "width"], inputs["pixel_values"].shape))
)

return inputs

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
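
For context, the OpenVINO backend now derives concrete dimensions from the prepared inputs themselves and folds them into model_shapes, which the reshape step above consumes, instead of tracking a separate input_shapes attribute. A standalone sketch of that bookkeeping:

# Standalone illustration of the shape update shown in prepare_inputs above;
# the starting model_shapes values are only an example.
import torch

model_shapes = {"height": 224, "width": 224}
inputs = {"input_ids": torch.zeros(4, 32, dtype=torch.long)}

if "input_ids" in inputs:
    model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))

print(model_shapes)  # {'height': 224, 'width': 224, 'batch_size': 4, 'sequence_length': 32}
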
3 changes: 2 additions & 1 deletion optimum_benchmark/backends/peft_utils.py
@@ -8,9 +8,10 @@
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_model # type: ignore


def apply_peft(model: PreTrainedModel, peft_type: str, peft_config: Dict[str, Any]) -> PreTrainedModel:
def apply_peft(model: "PreTrainedModel", peft_type: str, peft_config: Dict[str, Any]) -> "PreTrainedModel":
if not is_peft_available():
raise ImportError("peft is not available. Please, pip install peft.")

peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type](**peft_config)

return get_peft_model(model=model, peft_config=peft_config)
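
For reference, a minimal sketch of what apply_peft amounts to for a LoRA adapter, written directly against peft's public API; the base model and LoRA hyper-parameters are only illustrative, not defaults from this repository.

# Roughly equivalent to building a LoRA config and wrapping the model, as
# apply_peft does above via PEFT_TYPE_TO_CONFIG_MAPPING.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
model = get_peft_model(model=model, peft_config=peft_config)
model.print_trainable_parameters()
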
47 changes: 13 additions & 34 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -11,12 +11,12 @@
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
TorchAoConfig,
Trainer,
TrainerCallback,
TrainerState,
TrainingArguments,
)
from transformers import TorchAoConfig

from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
from ..base import Backend
@@ -28,7 +28,7 @@
import deepspeed # type: ignore

if is_torch_distributed_available():
import torch.distributed
import torch.distributed # type: ignore

if is_zentorch_available():
import zentorch # type: ignore # noqa: F401
@@ -332,18 +332,6 @@ def process_quantization_config(self) -> None:
else:
raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")

@property
def is_distributed(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

@property
def is_tp_distributed(self) -> bool:
return self.is_distributed and self.config.deepspeed_inference

@property
def is_dp_distributed(self) -> bool:
return self.is_distributed and not self.config.deepspeed_inference

@property
def is_quantized(self) -> bool:
return self.config.quantization_scheme is not None or (
@@ -420,35 +408,26 @@ def automodel_kwargs(self) -> Dict[str, Any]:

return kwargs

def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
raise ValueError(
f"Batch size {input_shapes['batch_size']} must be divisible by "
f"data parallel world size {torch.distributed.get_world_size()}"
)
# distributing batch size across processes
input_shapes["batch_size"] //= torch.distributed.get_world_size()

if self.is_tp_distributed:
if torch.distributed.get_rank() != 0:
# zeroing throughput on other ranks
input_shapes["batch_size"] = 0

return input_shapes
@property
def split_between_processes(self) -> bool:
return (
is_torch_distributed_available()
and torch.distributed.is_initialized()
and not self.config.deepspeed_inference
)

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.is_dp_distributed:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

if self.config.library == "timm":
inputs = {"x": inputs["pixel_values"]}

for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.to(self.config.device)

if self.config.library == "timm":
inputs = {"x": inputs["pixel_values"]}

return inputs

@torch.inference_mode()
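
The PyTorch backend carries the core of this PR: rather than separate is_dp_distributed and is_tp_distributed flags, it now only asks whether inputs should be split between processes, with DeepSpeed tensor parallelism being the case where every rank keeps the full batch. A condensed sketch of that dispatch, using the names from the diff above (not a drop-in copy of the backend):

# Condensed illustration; `deepspeed_inference` and `device` stand in for the
# corresponding backend config fields.
import torch
import torch.distributed
from accelerate import Accelerator

def prepare_inputs(inputs, deepspeed_inference: bool, device: str):
    split = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and not deepspeed_inference  # under tensor parallelism every rank keeps the full batch
    )

    if split:
        with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
            inputs = process_inputs

    # move tensors to the target device, as the backend's prepare_inputs does
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}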