added data parallel support
IlyasMoutawwakil committed Nov 24, 2023
1 parent 9a1b899 commit da67e3b
Showing 4 changed files with 175 additions and 23 deletions.
43 changes: 28 additions & 15 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -1,21 +1,16 @@
import gc
import os
from logging import getLogger
-from typing import TYPE_CHECKING, Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List

import torch

-if torch.distributed.is_available():
-    import torch.distributed

-if TYPE_CHECKING:
-    from datasets import Dataset
-    from transformers import TrainerCallback, TrainerState
-    from transformers.utils import ModelOutput
+from datasets import Dataset
+from transformers import TrainerCallback, TrainerState
+from transformers.utils import ModelOutput

from ..base import Backend
from .config import PyTorchConfig
-from .utils import DTYPES_MAPPING, randomize_weights, to_pow2
+from .utils import DTYPES_MAPPING, randomize_weights, TransformersDataParallel

# backend logger
LOGGER = getLogger("pytorch")
@@ -107,13 +102,19 @@ def configure(self, config: PyTorchConfig) -> None:
self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config)

if self.config.deepspeed_inference:
LOGGER.info("\t+ Using DeepSpeed Inference")
LOGGER.info("\t+ Using DeepSpeed-Inference")
from deepspeed import init_inference

self.pretrained_model = init_inference(
-    self.pretrained_model, config=self.config.deepspeed_inference_config
+    self.pretrained_model,
+    config=self.config.deepspeed_inference_config,
+    dtype=self.torch_dtype,
)

+if self.config.data_parallel:
+    LOGGER.info("\t+ Using TransformersDataParallel")
+    self.pretrained_model = TransformersDataParallel(self.pretrained_model)

def load_model_from_pretrained(self) -> None:
# inline quantization or quantization config modification
if self.config.quantization_scheme == "gptq":
@@ -160,7 +161,7 @@ def load_model_from_pretrained(self) -> None:
**self.automodel_kwargs,
**self.hub_kwargs,
)
-elif hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
+elif self.is_quantized():
LOGGER.info(f"\t+ Loading quantized model and moving it to device: {self.device}")
self.pretrained_model = self.automodel_class.from_pretrained(
self.model,
@@ -179,12 +180,24 @@ def load_model_from_pretrained(self) -> None:
**self.hub_kwargs,
)

+def is_quantized(self) -> bool:
+    return self.config.quantization_scheme is not None or hasattr(self.pretrained_config, "quantization_config")

+def is_gptq_model(self) -> bool:
+    return self.config.quantization_scheme == "gptq" or (
+        hasattr(self.pretrained_config, "quantization_config")
+        and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
+    )

@property
def automodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

-if hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
-    kwargs["low_cpu_mem_usage"] = True
+if self.is_quantized():
+    if self.is_gptq_model():
+        kwargs["device_map"] = torch.device(self.device)
+    else:
+        kwargs["low_cpu_mem_usage"] = True

if self.quantization_config is not None:
kwargs["quantization_config"] = self.quantization_config
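
The two new branches in configure() wrap the already-loaded model for multi-GPU inference: DeepSpeed-Inference now also receives the benchmark's torch dtype, and the new data_parallel option wraps the model in TransformersDataParallel. A minimal standalone sketch of the resulting wrapping order (the checkpoint name and the DeepSpeed config dict are illustrative assumptions, not taken from this commit):

import torch
from transformers import AutoModelForCausalLM

from optimum_benchmark.backends.pytorch.utils import TransformersDataParallel

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda:0")  # placeholder checkpoint

deepspeed_inference = False  # maps to PyTorchConfig.deepspeed_inference
data_parallel = True         # maps to PyTorchConfig.data_parallel (new in this commit)

if deepspeed_inference:
    from deepspeed import init_inference

    # example inference config; the benchmark passes the user-supplied deepspeed_inference_config
    model = init_inference(model, config={"tensor_parallel": {"tp_size": 1}}, dtype=torch.float16)

if data_parallel:
    # replicates the module on every visible GPU and splits each generate() batch across them
    model = TransformersDataParallel(model)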
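
The new is_quantized()/is_gptq_model() helpers also change how quantized checkpoints are loaded: GPTQ models now get an explicit device_map so the weights land directly on the target device, while other quantized models keep the low_cpu_mem_usage path. A rough standalone equivalent (the checkpoint name and device are placeholders for what the benchmark config provides):

import torch
from transformers import AutoModelForCausalLM

device = torch.device("cuda:0")   # stands in for self.device
quantization_scheme = "gptq"      # stands in for PyTorchConfig.quantization_scheme

kwargs = {}
if quantization_scheme is not None:         # is_quantized()
    if quantization_scheme == "gptq":       # is_gptq_model()
        kwargs["device_map"] = device       # place GPTQ weights on the GPU at load time
    else:
        kwargs["low_cpu_mem_usage"] = True  # avoid a full in-RAM copy for other schemes

# illustrative GPTQ checkpoint; loading it requires optimum and auto-gptq to be installed
model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-GPTQ", **kwargs)
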
5 changes: 3 additions & 2 deletions optimum_benchmark/backends/pytorch/config.py
@@ -44,8 +44,8 @@ class PyTorchConfig(BackendConfig):
torch_dtype: Optional[str] = None

# inference options
-disable_grad: bool = "${is_inference:${benchmark.name}}"
eval_mode: bool = "${is_inference:${benchmark.name}}"
+disable_grad: bool = "${is_inference:${benchmark.name}}"

# automatic mixed precision options
amp_autocast: bool = False
@@ -63,7 +63,8 @@ class PyTorchConfig(BackendConfig):
quantization_scheme: Optional[str] = None
quantization_config: Dict[str, Any] = field(default_factory=dict)

-# distributed options
+# distributed inference options
+data_parallel: bool = False
deepspeed_inference: bool = False
deepspeed_inference_config: Dict[str, Any] = field(default_factory=dict)

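
For reference, a sketch of setting the new flag when building the backend config programmatically. Direct dataclass construction outside of Hydra is an assumption here (the interpolated defaults such as "${is_inference:...}" are normally resolved by the CLI, where the same field would be set with an override like backend.data_parallel=true):

from optimum_benchmark.backends.pytorch.config import PyTorchConfig

# data_parallel replicates the whole model on every visible GPU,
# while deepspeed_inference hands the model to DeepSpeed's kernel injection instead
config = PyTorchConfig(data_parallel=True, deepspeed_inference=False)
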
143 changes: 140 additions & 3 deletions optimum_benchmark/backends/pytorch/utils.py
@@ -1,6 +1,13 @@
-import math
import threading
from itertools import chain
from typing import Any, Dict, List, Optional, Sequence, Union, cast

import torch
from torch._utils import ExceptionWrapper
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch.nn.modules import Module
from torch.nn.parallel.parallel_apply import get_a_var

DTYPES_MAPPING = {
"float32": "fp32",
@@ -22,5 +29,135 @@ def randomize_weights(model):
param.data.normal_(mean=0.0, std=0.2)


-def to_pow2(x: int) -> int:
-    return 2 ** int(math.ceil(math.log2(x)))
# adapted from torch to use generate instead of forward
def parallel_generate_apply(
modules: Sequence[Module],
inputs: Sequence[Any],
kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.
Args:
modules (Module): modules to be parallelized
inputs (tensor): inputs to the modules
devices (list of int or torch.device): CUDA devices
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(
inputs
), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
devices = [_get_device_index(x, True) for x in devices]
streams = [torch.cuda.current_stream(x) for x in devices]
lock = threading.Lock()
results = {}
grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

def _worker(
i: int,
module: Module,
input: Any,
kwargs: Dict[str, Any],
device: Optional[Union[int, torch.device]] = None,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
torch.set_grad_enabled(grad_enabled)
if device is None:
t = get_a_var(input)
if t is None:
with lock:
results[i] = ExceptionWrapper(
where=f"in replica {i}, no device was provided and no tensor input was found; "
"device cannot be resolved"
)
return
device = t.get_device()
if stream is None:
stream = torch.cuda.current_stream(device)
try:
with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = module.generate(*input, **kwargs)
with lock:
results[i] = output
except Exception:
with lock:
results[i] = ExceptionWrapper(where=f"in replica {i} on device {device}")

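# with more than one replica, a worker thread per replica runs module.generate under its assigned
# device and stream; outputs (or wrapped exceptions) are collected in `results` under the shared lock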
if len(modules) > 1:
threads = [
threading.Thread(target=_worker, args=(i, module, input, kwargs, device, stream))
for i, (module, input, kwargs, device, stream) in enumerate(
zip(modules, inputs, kwargs_tup, devices, streams)
)
]

for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs


# adapted from torch to support generate
class TransformersDataParallel(torch.nn.DataParallel):
def generate(self, *inputs: Any, **kwargs: Any) -> Any:
with torch.autograd.profiler.record_function("DataParallel.generate"):
if not self.device_ids:
return self.module.generate(*inputs, **kwargs)

for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
"module must have its parameters and buffers "
f"on device {self.src_device_obj} (device_ids[0]) but found one of "
f"them on device: {t.device}"
)

inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
# for forward function without any inputs, empty list and dict will be created
# so the module can be executed on one device which is the first one in device_ids
if not inputs and not module_kwargs:
inputs = ((),)
module_kwargs = ({},)

if len(self.device_ids) == 1:
return self.module.generate(*inputs[0], **module_kwargs[0])

replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
outputs = self.parallel_generate_apply(replicas, inputs, module_kwargs)
return self.gather(outputs, self.output_device)

def parallel_generate_apply(self, replicas: Sequence, inputs: Sequence, kwargs: Any) -> List[Any]:
return parallel_generate_apply(replicas, inputs, kwargs, self.device_ids[: len(replicas)])

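# nn.DataParallel only resolves parameters, buffers and submodules through __getattr__;
# fall back to the wrapped model so attributes like config or generation_config stay reachable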
def __getattr__(self, name: str) -> Any:
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)
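
To illustrate the intended use of the wrapper, a minimal sketch assuming at least two visible CUDA devices and an illustrative checkpoint (neither is part of this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum_benchmark.backends.pytorch.utils import TransformersDataParallel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda:0")  # parameters must sit on device_ids[0]
model = TransformersDataParallel(model)

prompts = ["Hello world"] * 8  # identical prompts keep per-GPU output lengths equal
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda:0")

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=16, do_sample=False)

Note that gather() concatenates the per-device outputs along the batch dimension, which is why the sketch keeps all prompts (and hence all generated sequences) the same length.
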
7 changes: 4 additions & 3 deletions optimum_benchmark/benchmarks/inference/benchmark.py
@@ -286,6 +286,7 @@ def get_results_df(self) -> DataFrame:
return DataFrame(results_dict, index=[0])

def save(self) -> None:
LOGGER.info("Saving results")
results_df = self.get_results_df()
results_df.to_csv("inference_results.csv", index=False)
if os.environ.get("LOCAL_RANK", "0") == "0":
LOGGER.info("Saving results")
results_df = self.get_results_df()
results_df.to_csv("inference_results.csv", index=False)
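
The added LOCAL_RANK guard matters once the benchmark runs under a multi-process launcher, where every rank would otherwise write inference_results.csv concurrently. A tiny sketch of the same pattern (the helper is hypothetical):

import os

def save_results() -> None:
    ...  # placeholder for writing inference_results.csv

# torchrun and the deepspeed launcher export LOCAL_RANK for every spawned process;
# plain single-process runs have no LOCAL_RANK, so the default "0" keeps saving enabled
if os.environ.get("LOCAL_RANK", "0") == "0":
    save_results()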
