Data Parallel support #89

Merged · 2 commits · Nov 24, 2023

2 changes: 1 addition & 1 deletion optimum_benchmark/backends/config.py
@@ -1,6 +1,6 @@
 from abc import ABC
-from logging import getLogger
 from dataclasses import dataclass
+from logging import getLogger
 from typing import Optional, TypeVar

 from psutil import cpu_count

43 changes: 28 additions & 15 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -1,21 +1,16 @@
 import gc
 import os
 from logging import getLogger
-from typing import TYPE_CHECKING, Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List

 import torch
-
-if torch.distributed.is_available():
-    import torch.distributed
-
-if TYPE_CHECKING:
-    from datasets import Dataset
-    from transformers import TrainerCallback, TrainerState
-    from transformers.utils import ModelOutput
+from datasets import Dataset
+from transformers import TrainerCallback, TrainerState
+from transformers.utils import ModelOutput

 from ..base import Backend
 from .config import PyTorchConfig
-from .utils import DTYPES_MAPPING, randomize_weights, to_pow2
+from .utils import DTYPES_MAPPING, TransformersDataParallel, randomize_weights

 # backend logger
 LOGGER = getLogger("pytorch")
@@ -107,13 +102,19 @@ def configure(self, config: PyTorchConfig) -> None:
             self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config)

         if self.config.deepspeed_inference:
-            LOGGER.info("\t+ Using DeepSpeed Inference")
+            LOGGER.info("\t+ Using DeepSpeed-Inference")
             from deepspeed import init_inference

             self.pretrained_model = init_inference(
-                self.pretrained_model, config=self.config.deepspeed_inference_config
+                self.pretrained_model,
+                config=self.config.deepspeed_inference_config,
+                dtype=self.torch_dtype if self.torch_dtype is not None else self.pretrained_model.dtype,
             )

+        if self.config.data_parallel:
+            LOGGER.info("\t+ Using TransformersDataParallel")
+            self.pretrained_model = TransformersDataParallel(self.pretrained_model)
+
     def load_model_from_pretrained(self) -> None:
         # inline quantization or quantization config modification
         if self.config.quantization_scheme == "gptq":
@@ -160,7 +161,7 @@ def load_model_from_pretrained(self) -> None:
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
             )
-        elif hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
+        elif self.is_quantized():
             LOGGER.info(f"\t+ Loading quantized model and moving it to device: {self.device}")
             self.pretrained_model = self.automodel_class.from_pretrained(
                 self.model,
@@ -179,12 +180,24 @@ def load_model_from_pretrained(self) -> None:
                 **self.hub_kwargs,
             )

+    def is_quantized(self) -> bool:
+        return self.config.quantization_scheme is not None or hasattr(self.pretrained_config, "quantization_config")
+
+    def is_gptq_model(self) -> bool:
+        return self.config.quantization_scheme == "gptq" or (
+            hasattr(self.pretrained_config, "quantization_config")
+            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+        )
+
     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
         kwargs = {}

-        if hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
-            kwargs["low_cpu_mem_usage"] = True
+        if self.is_quantized():
+            if self.is_gptq_model():
+                kwargs["device_map"] = torch.device(self.device)
+            else:
+                kwargs["low_cpu_mem_usage"] = True

         if self.quantization_config is not None:
             kwargs["quantization_config"] = self.quantization_config
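
For reviewers, a standalone sketch of the new loading-kwargs logic. The free-standing function and its arguments are illustrative stand-ins for the backend's attributes; in the PR this logic lives on the backend class as shown above.

from typing import Any, Dict, Optional

import torch


def automodel_kwargs(scheme: Optional[str], pretrained_config: Any, device: str) -> Dict[str, Any]:
    # GPTQ checkpoints are already quantized on disk, so they are loaded
    # straight onto the target device via device_map; other quantized models
    # keep the low_cpu_mem_usage path
    quantization_config = getattr(pretrained_config, "quantization_config", None)
    is_quantized = scheme is not None or quantization_config is not None
    is_gptq = scheme == "gptq" or (
        quantization_config is not None and quantization_config.get("quant_method", None) == "gptq"
    )

    kwargs: Dict[str, Any] = {}
    if is_quantized:
        if is_gptq:
            kwargs["device_map"] = torch.device(device)
        else:
            kwargs["low_cpu_mem_usage"] = True
    return kwargs
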
5 changes: 3 additions & 2 deletions optimum_benchmark/backends/pytorch/config.py
@@ -44,8 +44,8 @@ class PyTorchConfig(BackendConfig):
     torch_dtype: Optional[str] = None

     # inference options
-    disable_grad: bool = "${is_inference:${benchmark.name}}"
     eval_mode: bool = "${is_inference:${benchmark.name}}"
+    disable_grad: bool = "${is_inference:${benchmark.name}}"

     # automatic mixed precision options
     amp_autocast: bool = False
@@ -63,7 +63,8 @@ class PyTorchConfig(BackendConfig):
     quantization_scheme: Optional[str] = None
     quantization_config: Dict[str, Any] = field(default_factory=dict)

-    # distributed options
+    # distributed inference options
+    data_parallel: bool = False
     deepspeed_inference: bool = False
     deepspeed_inference_config: Dict[str, Any] = field(default_factory=dict)

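
For context on the "${is_inference:${benchmark.name}}" defaults above, a minimal sketch of how such an OmegaConf resolver behaves. The resolver registration here is an assumption for illustration, not part of this diff.

from omegaconf import OmegaConf

# assumed resolver: resolves to True only when the composed benchmark is "inference"
OmegaConf.register_new_resolver("is_inference", lambda name: name == "inference")

cfg = OmegaConf.create(
    {
        "benchmark": {"name": "inference"},
        "backend": {
            "eval_mode": "${is_inference:${benchmark.name}}",
            "disable_grad": "${is_inference:${benchmark.name}}",
            "data_parallel": False,  # the new flag introduced by this PR
        },
    }
)
assert cfg.backend.eval_mode and cfg.backend.disable_grad
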
143 changes: 140 additions & 3 deletions optimum_benchmark/backends/pytorch/utils.py
@@ -1,6 +1,13 @@
-import math
+import threading
+from itertools import chain
+from typing import Any, Dict, List, Optional, Sequence, Union, cast

 import torch
+from torch._utils import ExceptionWrapper
+from torch.cuda._utils import _get_device_index
+from torch.cuda.amp import autocast
+from torch.nn.modules import Module
+from torch.nn.parallel.parallel_apply import get_a_var

 DTYPES_MAPPING = {
     "float32": "fp32",
@@ -22,5 +29,135 @@ def randomize_weights(model):
         param.data.normal_(mean=0.0, std=0.2)


-def to_pow2(x: int) -> int:
-    return 2 ** int(math.ceil(math.log2(x)))
+# adapted from torch to use generate instead of forward
+def parallel_generate_apply(
+    modules: Sequence[Module],
+    inputs: Sequence[Any],
+    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
+    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
+) -> List[Any]:
+    r"""Applies each `module` in :attr:`modules` in parallel on arguments
+    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+    on each of :attr:`devices`.
+
+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
+    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+    :attr:`devices` (if given) should all have same length. Moreover, each
+    element of :attr:`inputs` can either be a single object as the only argument
+    to a module, or a collection of positional arguments.
+    """
+    assert len(modules) == len(
+        inputs
+    ), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
+    if kwargs_tup is not None:
+        assert len(modules) == len(kwargs_tup)
+    else:
+        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)
+    devices = [_get_device_index(x, True) for x in devices]
+    streams = [torch.cuda.current_stream(x) for x in devices]
+    lock = threading.Lock()
+    results = {}
+    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
+
+    def _worker(
+        i: int,
+        module: Module,
+        input: Any,
+        kwargs: Dict[str, Any],
+        device: Optional[Union[int, torch.device]] = None,
+        stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        torch.set_grad_enabled(grad_enabled)
+        if device is None:
+            t = get_a_var(input)
+            if t is None:
+                with lock:
+                    results[i] = ExceptionWrapper(
+                        where=f"in replica {i}, no device was provided and no tensor input was found; "
+                        "device cannot be resolved"
+                    )
+                return
+            device = t.get_device()
+        if stream is None:
+            stream = torch.cuda.current_stream(device)
+        try:
+            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
+                # this also avoids accidental slicing of `input` if it is a Tensor
+                if not isinstance(input, (list, tuple)):
+                    input = (input,)
+                output = module.generate(*input, **kwargs)
+            with lock:
+                results[i] = output
+        except Exception:
+            with lock:
+                results[i] = ExceptionWrapper(where=f"in replica {i} on device {device}")
+
+    if len(modules) > 1:
+        threads = [
+            threading.Thread(target=_worker, args=(i, module, input, kwargs, device, stream))
+            for i, (module, input, kwargs, device, stream) in enumerate(
+                zip(modules, inputs, kwargs_tup, devices, streams)
+            )
+        ]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
+
+    outputs = []
+    for i in range(len(inputs)):
+        output = results[i]
+        if isinstance(output, ExceptionWrapper):
+            output.reraise()
+        outputs.append(output)
+    return outputs
+
+
+# adapted from torch to support generate
+class TransformersDataParallel(torch.nn.DataParallel):
+    def generate(self, *inputs: Any, **kwargs: Any) -> Any:
+        with torch.autograd.profiler.record_function("DataParallel.generate"):
+            if not self.device_ids:
+                return self.module.generate(*inputs, **kwargs)
+
+            for t in chain(self.module.parameters(), self.module.buffers()):
+                if t.device != self.src_device_obj:
+                    raise RuntimeError(
+                        "module must have its parameters and buffers "
+                        f"on device {self.src_device_obj} (device_ids[0]) but found one of "
+                        f"them on device: {t.device}"
+                    )
+
+            inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            # for forward function without any inputs, empty list and dict will be created
+            # so the module can be executed on one device which is the first one in device_ids
+            if not inputs and not module_kwargs:
+                inputs = ((),)
+                module_kwargs = ({},)
+
+            if len(self.device_ids) == 1:
+                return self.module.generate(*inputs[0], **module_kwargs[0])
+
+            replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
+            outputs = self.parallel_generate_apply(replicas, inputs, module_kwargs)
+            return self.gather(outputs, self.output_device)
+
+    def parallel_generate_apply(self, replicas: Sequence, inputs: Sequence, kwargs: Any) -> List[Any]:
+        return parallel_generate_apply(replicas, inputs, kwargs, self.device_ids[: len(replicas)])
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
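
To make the intended flow concrete, a usage sketch of the new wrapper. It assumes a host with at least two CUDA GPUs; the model name and batch are placeholders, and gathering assumes all replicas return equally shaped sequences (hence the fixed min_new_tokens/max_new_tokens).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum_benchmark.backends.pytorch.utils import TransformersDataParallel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda:0")
model = TransformersDataParallel(model)  # one replica per visible GPU

# the batch is scattered along dim 0, generate() runs in one thread per replica
# (each on its own CUDA stream), and outputs are gathered back on cuda:0
batch = tokenizer(["hello world"] * 8, return_tensors="pt").to("cuda:0")
sequences = model.generate(**batch, max_new_tokens=8, min_new_tokens=8)
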
7 changes: 4 additions & 3 deletions optimum_benchmark/benchmarks/inference/benchmark.py
@@ -286,6 +287,7 @@ def get_results_df(self) -> DataFrame:
         return DataFrame(results_dict, index=[0])

     def save(self) -> None:
-        LOGGER.info("Saving results")
-        results_df = self.get_results_df()
-        results_df.to_csv("inference_results.csv", index=False)
+        if os.environ.get("LOCAL_RANK", "0") == "0":
+            LOGGER.info("Saving results")
+            results_df = self.get_results_df()
+            results_df.to_csv("inference_results.csv", index=False)
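
The save() change makes local rank 0 the only writer under distributed launchers. A minimal sketch of the same guard in isolation; torchrun sets LOCAL_RANK per process, and in a single-process run the variable is absent so the default "0" keeps saving enabled.

import os

from pandas import DataFrame


def save_results(results_df: DataFrame) -> None:
    # only the first process on each node writes, so concurrent ranks
    # do not race on the same CSV file
    if os.environ.get("LOCAL_RANK", "0") == "0":
        results_df.to_csv("inference_results.csv", index=False)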