From 10e4eceb0378a56ef92493c55212c11b599e3f36 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 29 Apr 2024 09:22:17 +0200 Subject: [PATCH] Fix isolation (#186) --- .../test_cli_rocm_pytorch_single_gpu.yaml | 6 +- Makefile | 1 - examples/pytorch_bert.yaml | 1 + examples/pytorch_llama.yaml | 7 +- examples/pytorch_timm.yaml | 7 +- examples/trt_llama.yaml | 4 + optimum_benchmark/backends/base.py | 5 +- optimum_benchmark/backends/pytorch/backend.py | 217 +++++++++--------- optimum_benchmark/backends/pytorch/config.py | 17 +- .../backends/torch_ort/backend.py | 7 +- optimum_benchmark/benchmarks/report.py | 8 +- optimum_benchmark/experiment.py | 29 +-- optimum_benchmark/launchers/config.py | 9 +- .../launchers/device_isolation_utils.py | 214 +++++++++++++++++ optimum_benchmark/launchers/inline/config.py | 10 + .../launchers/inline/launcher.py | 14 +- .../launchers/isolation_utils.py | 178 -------------- .../launchers/process/launcher.py | 45 ++-- .../launchers/torchrun/config.py | 15 +- .../launchers/torchrun/launcher.py | 95 ++++---- optimum_benchmark/trackers/energy.py | 24 +- optimum_benchmark/trackers/latency.py | 68 +++--- optimum_benchmark/trackers/memory.py | 44 ++-- tests/configs/_base_.yaml | 2 + .../cuda_inference_pytorch_text_decoders.yaml | 2 +- .../cuda_inference_pytorch_text_encoders.yaml | 2 +- tests/test_api.py | 27 +-- tests/test_cli.py | 17 +- 28 files changed, 570 insertions(+), 505 deletions(-) create mode 100644 optimum_benchmark/launchers/device_isolation_utils.py delete mode 100644 optimum_benchmark/launchers/isolation_utils.py diff --git a/.github/workflows/test_cli_rocm_pytorch_single_gpu.yaml b/.github/workflows/test_cli_rocm_pytorch_single_gpu.yaml index a3bcd86a..f8ecb9a4 100644 --- a/.github/workflows/test_cli_rocm_pytorch_single_gpu.yaml +++ b/.github/workflows/test_cli_rocm_pytorch_single_gpu.yaml @@ -26,10 +26,8 @@ jobs: - name: Target devices run: | - echo "DEVICE0: $DEVICE0" - echo "DEVICE1: $DEVICE1" - echo "DEVICE0=$DEVICE0" >> $GITHUB_ENV - echo "DEVICE1=$DEVICE1" >> $GITHUB_ENV + echo "DEVICE: $DEVICE" + echo "DEVICE=$DEVICE" >> $GITHUB_ENV - name: Build image run: docker build diff --git a/Makefile b/Makefile index 0881a6c1..f9635dcc 100644 --- a/Makefile +++ b/Makefile @@ -70,7 +70,6 @@ run_rocm_container: docker run \ -it \ --rm \ - --pid host \ --shm-size 64G \ --device /dev/kfd \ --device /dev/dri \ diff --git a/examples/pytorch_bert.yaml b/examples/pytorch_bert.yaml index af98cdc8..3f566481 100644 --- a/examples/pytorch_bert.yaml +++ b/examples/pytorch_bert.yaml @@ -11,6 +11,7 @@ experiment_name: pytorch_bert launcher: device_isolation: true + device_isolation_action: warn benchmark: latency: true diff --git a/examples/pytorch_llama.yaml b/examples/pytorch_llama.yaml index 91bd5438..133ba54c 100644 --- a/examples/pytorch_llama.yaml +++ b/examples/pytorch_llama.yaml @@ -9,15 +9,16 @@ defaults: experiment_name: pytorch_llama +launcher: + device_isolation: true + device_isolation_action: warn + backend: device: cuda device_ids: 0 no_weights: true model: TheBloke/Llama-2-70B-AWQ -launcher: - device_isolation: true - benchmark: input_shapes: batch_size: 1 diff --git a/examples/pytorch_timm.yaml b/examples/pytorch_timm.yaml index 6091da15..619d4861 100644 --- a/examples/pytorch_timm.yaml +++ b/examples/pytorch_timm.yaml @@ -9,14 +9,15 @@ defaults: experiment_name: pytorch_timm +launcher: + device_isolation: true + device_isolation_action: warn + backend: device: cuda device_ids: 0 model: 
timm/mobilenetv3_large_100.ra_in1k -launcher: - device_isolation: true - benchmark: memory: true input_shapes: diff --git a/examples/trt_llama.yaml b/examples/trt_llama.yaml index 0d3a1d87..7eab5c67 100644 --- a/examples/trt_llama.yaml +++ b/examples/trt_llama.yaml @@ -9,6 +9,10 @@ defaults: experiment_name: trt_llama +launcher: + device_isolation: true + device_isolation_action: warn + backend: device: cuda device_ids: 0 diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py index 6a8053a6..52e5ae19 100644 --- a/optimum_benchmark/backends/base.py +++ b/optimum_benchmark/backends/base.py @@ -120,11 +120,8 @@ def train(self, **kwargs) -> TrainerState: """ raise NotImplementedError("Backend must implement train method") - def delete_pretrained_model(self) -> None: + def clean(self) -> None: if hasattr(self, "pretrained_model"): del self.pretrained_model - def clean(self) -> None: - LOGGER.info(f"Cleaning {self.NAME} backend") - self.delete_pretrained_model() gc.collect() diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index 0314844e..1f829a5c 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -18,17 +18,15 @@ TrainingArguments, ) -from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available +from ...import_utils import is_deepspeed_available, is_zentorch_available from ..base import Backend from ..peft_utils import apply_peft from ..transformers_utils import random_init_weights from .config import PyTorchConfig if is_deepspeed_available(): - from deepspeed import init_inference + import deepspeed -if is_torch_distributed_available(): - import torch.distributed if is_zentorch_available(): import zentorch # type: ignore # noqa: F401 @@ -45,7 +43,15 @@ def __init__(self, config: PyTorchConfig): super().__init__(config) self.validate_library() - # Thread settings + if self.config.deepspeed_inference and self.is_quantized: + raise ValueError("Deepspeed-Inference is not compatible with Transformers quantization") + + # Quantization + if self.is_quantized: + LOGGER.info("\t+ Processing quantization config") + self.process_quantization_config() + + # Threads if self.config.inter_op_num_threads is not None: LOGGER.info(f"\t+ Setting pytorch inter_op_num_threads({self.config.inter_op_num_threads}))") torch.set_num_threads(self.config.inter_op_num_threads) @@ -53,27 +59,25 @@ def __init__(self, config: PyTorchConfig): LOGGER.info(f"\t+ Setting pytorch intra_op_num_threads({self.config.intra_op_num_threads}))") torch.set_num_interop_threads(self.config.intra_op_num_threads) - # Mixed precision - if self.config.amp_dtype: - LOGGER.info(f"\t+ Setting mixed precision dtype to {self.config.amp_dtype}") - self.amp_dtype = getattr(torch, self.config.amp_dtype) - else: - self.amp_dtype = None - - # Quantization - if self.is_quantized: - LOGGER.info("\t+ Processing quantization config") - self.process_quantization_config() - else: - self.quantization_config = None - - if self.config.deepspeed_inference: - if self.quantization_config is not None: - raise ValueError("Deepspeed-Inference is not compatible with Transformers quantization") + # Autocast + if self.config.autocast_enabled: + LOGGER.info("\t+ Enabling automatic mixed precision") + torch.set_autocast_enabled(True) + + if self.config.autocast_dtype is not None: + if self.config.device == "cpu": + LOGGER.info(f"\t+ Setting autocast cpu dtype to 
{self.config.autocast_dtype}") + torch.set_autocast_cpu_dtype(getattr(torch, self.config.autocast_dtype)) + elif self.config.device == "cuda": + LOGGER.info(f"\t+ Setting autocast gpu dtype to {self.config.autocast_dtype}") + torch.set_autocast_gpu_dtype(getattr(torch, self.config.autocast_dtype)) + else: + raise ValueError(f"Device {self.config.device} not supported for autocast") LOGGER.info("\t+ Creating backend temporary directory") self.tmpdir = TemporaryDirectory() + # Model if self.config.no_weights and (self.config.library == "diffusers" or self.config.library == "timm"): raise ValueError("Diffusion pipelines and Timm models don't support no weights") elif self.config.no_weights: @@ -83,6 +87,9 @@ def __init__(self, config: PyTorchConfig): LOGGER.info("\t+ Loading model with pretrained weights") self.load_model_from_pretrained() + self.tmpdir.cleanup() + + # KV-Cache if self.config.cache_implementation is not None: LOGGER.info(f"\t+ Setting cache implementation to {self.config.cache_implementation}") self.pretrained_model.generation_config.cache_implementation = self.config.cache_implementation @@ -97,14 +104,15 @@ def __init__(self, config: PyTorchConfig): LOGGER.info("\t+ Enabling BetterTransformer") self.pretrained_model.to_bettertransformer() + # PEFT + if self.config.peft_type is not None: + LOGGER.info("\t+ Applying PEFT") + self.pretrained_model = apply_peft(self.pretrained_model, self.config.peft_type, self.config.peft_config) + # Torch compile if self.config.torch_compile: - if self.config.device == "cuda" and torch.cuda.get_device_capability(0)[0] >= 8: - LOGGER.info("\t+ Setting float32_matmul_precision to high") - torch.set_float32_matmul_precision("high") - if self.config.library == "diffusers": - LOGGER.info("\t+ Using torch.compile to compile unet and vae") + LOGGER.info("\t+ Using torch.compile on unet and vae") self.pretrained_model.unet = torch.compile( self.pretrained_model.unet, **self.config.torch_compile_config ) @@ -112,24 +120,23 @@ def __init__(self, config: PyTorchConfig): self.pretrained_model.vae.decode, **self.config.torch_compile_config ) else: - LOGGER.info("\t+ Using torch.compile on forward pass") - self.pretrained_model.forward = torch.compile( - self.pretrained_model.forward, **self.config.torch_compile_config - ) - - if self.config.peft_type is not None: - LOGGER.info("\t+ Applying PEFT") - self.pretrained_model = apply_peft(self.pretrained_model, self.config.peft_type, self.config.peft_config) + LOGGER.info("\t+ Using torch.compile on model") + self.pretrained_model = torch.compile(self.pretrained_model, **self.config.torch_compile_config) - self.tmpdir.cleanup() + # DeepSpeed + if self.config.deepspeed_inference: + LOGGER.info("\t+ Initializing DeepSpeed Inference Engine") + self.pretrained_model = deepspeed.init_inference( + model=self.pretrained_model, config=self.config.deepspeed_inference_config + ) def validate_library(self) -> None: if self.config.library == "timm": - LOGGER.info(f"\t+ Using Timm method {self.automodel_class.__name__}") + LOGGER.info(f"\t+ Using Timm's {self.automodel_class.__name__}") elif self.config.library == "diffusers": - LOGGER.info(f"\t+ Using Pipeline class {self.automodel_class.__name__}") + LOGGER.info(f"\t+ Using Diffusers Pipeline {self.automodel_class.__name__}") elif self.config.library == "transformers": - LOGGER.info(f"\t+ Using AutoModel class {self.automodel_class.__name__}") + LOGGER.info(f"\t+ Using AutoModel {self.automodel_class.__name__}") else: raise ValueError(f"Library {self.config.library} not 
supported") @@ -140,6 +147,7 @@ def load_model_from_pretrained(self) -> None: if self.config.device != "cpu": LOGGER.info(f"\t+ Moving model to device: {self.config.device}") self.pretrained_model.to(self.config.device) + elif self.config.library == "diffusers": LOGGER.info("\t+ Loading Diffusion pipeline") self.pretrained_model = self.automodel_class.from_pretrained( @@ -152,56 +160,36 @@ def load_model_from_pretrained(self) -> None: if self.config.device_map is None and self.config.device != "cpu": LOGGER.info(f"\t+ Moving pipeline to device: {self.config.device}") self.pretrained_model.to(self.config.device) - elif self.config.deepspeed_inference: - if self.config.no_weights: - with torch.device("meta"): - LOGGER.info("\t+ Loading model on meta device for fast initialization") - self.pretrained_model = self.automodel_class.from_pretrained( - pretrained_model_name_or_path=self.config.model, - **self.config.hub_kwargs, - **self.automodel_kwargs, - ) - LOGGER.info("\t+ Materializing model on CPU") - self.pretrained_model.to_empty(device="cpu") - LOGGER.info("\t+ Tying model weights") - self.pretrained_model.tie_weights() - else: - LOGGER.info("\t+ Loading model on cpu to avoid OOM") - with torch.device("cpu"): - self.pretrained_model = self.automodel_class.from_pretrained( - pretrained_model_name_or_path=self.config.model, - **self.config.hub_kwargs, - **self.automodel_kwargs, - ) - - torch.distributed.barrier() # better safe than hanging - LOGGER.info("\t+ Initializing DeepSpeed Inference Engine") - self.pretrained_model = init_inference(self.pretrained_model, config=self.config.deepspeed_inference_config) - torch.distributed.barrier() # better safe than hanging + elif self.is_quantized: - # we can't use device context manager on quantized models - LOGGER.info("\t+ Loading Quantized model") + LOGGER.info(f"\t+ Loading {self.quantization_config.quant_method}-quantized model") self.pretrained_model = self.automodel_class.from_pretrained( pretrained_model_name_or_path=self.config.model, device_map=self.config.device_map or torch.device(self.config.device), + # quantized models are more compatible with device_map dispatcher than (to(device)) + # using to(device) on quantized models sometimes leaves some layers on cpu or raises + # an error because the layers are already on the device **self.config.hub_kwargs, **self.automodel_kwargs, ) + elif self.config.device_map is not None: - # we can't use device context manager since device_map is specified - LOGGER.info(f"\t+ Loading model with device map: {self.config.device_map}") + LOGGER.info(f"\t+ Loading Transformers model with device map: {self.config.device_map}") self.pretrained_model = self.automodel_class.from_pretrained( pretrained_model_name_or_path=self.config.model, device_map=self.config.device_map, **self.config.hub_kwargs, **self.automodel_kwargs, ) + else: - LOGGER.info(f"\t+ Loading model directly on device: {self.config.device}") - with torch.device(self.config.device): - self.pretrained_model = self.automodel_class.from_pretrained( - pretrained_model_name_or_path=self.config.model, **self.config.hub_kwargs, **self.automodel_kwargs - ) + LOGGER.info("\t+ Loading Transformers model") + self.pretrained_model = self.automodel_class.from_pretrained( + pretrained_model_name_or_path=self.config.model, **self.config.hub_kwargs, **self.automodel_kwargs + ) + if self.config.device != "cpu": + LOGGER.info(f"\t+ Moving model to device: {self.config.device}") + self.pretrained_model.to(self.config.device) def create_no_weights_model(self) -> 
None: if self.pretrained_config is None: @@ -231,18 +219,32 @@ def create_no_weights_model(self) -> None: # tricking from_pretrained to load the model as if it was quantized LOGGER.info("\t+ Saving no weights model pretrained config") - if self.config.library == "transformers": - self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) + self.pretrained_config.save_pretrained(save_directory=self.no_weights_model) def load_model_with_no_weights(self) -> None: LOGGER.info("\t+ Creating no weights model") self.create_no_weights_model() - with random_init_weights(): - original_model, self.config.model = self.config.model, self.no_weights_model - LOGGER.info("\t+ Loading no weights AutoModel") - self.load_model_from_pretrained() - self.config.model = original_model + if self.config.deepspeed_inference: + with torch.device("meta"): + # with big models, loading no_weights_model is very slow (randomizing every weight) + # so we load the model on meta device to speed up the process and then move it to cpu + LOGGER.info("\t+ Loading Transformers model on meta device for fast initialization") + self.pretrained_model = self.automodel_class.from_pretrained( + pretrained_model_name_or_path=self.config.model, + **self.config.hub_kwargs, + **self.automodel_kwargs, + ) + LOGGER.info("\t+ Materializing meta model on CPU to avoid OOM") + self.pretrained_model.to_empty(device="cpu") + LOGGER.info("\t+ Tying model weights") + self.pretrained_model.tie_weights() + else: + with random_init_weights(): + original_model, self.config.model = self.config.model, self.no_weights_model + LOGGER.info("\t+ Loading no weights AutoModel") + self.load_model_from_pretrained() + self.config.model = original_model def process_quantization_config(self) -> None: if self.is_gptq_quantized: @@ -265,7 +267,10 @@ def process_quantization_config(self) -> None: @property def is_quantized(self) -> bool: - return self.config.quantization_scheme is not None or hasattr(self.pretrained_config, "quantization_config") + return self.config.quantization_scheme is not None or ( + hasattr(self.pretrained_config, "quantization_config") + and self.pretrained_config.quantization_config.get("quant_method", None) is not None + ) @property def is_bnb_quantized(self) -> bool: @@ -290,15 +295,19 @@ def is_awq_quantized(self) -> bool: @property def is_exllamav2(self) -> bool: - return (self.is_gptq_quantized or self.is_awq_quantized) and ( - ( - hasattr(self.pretrained_config, "quantization_config") - and hasattr(self.pretrained_config.quantization_config, "exllama_config") - and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2 - ) - or ( - "exllama_config" in self.config.quantization_config - and self.config.quantization_config["exllama_config"].get("version", None) == 2 + return ( + self.is_quantized + and (self.is_gptq_quantized or self.is_awq_quantized) + and ( + ( + hasattr(self.pretrained_config, "quantization_config") + and hasattr(self.pretrained_config.quantization_config, "exllama_config") + and self.pretrained_config.quantization_config.exllama_config.get("version", None) == 2 + ) + or ( + "exllama_config" in self.config.quantization_config + and self.config.quantization_config["exllama_config"].get("version", None) == 2 + ) ) ) @@ -306,12 +315,12 @@ def is_exllamav2(self) -> bool: def automodel_kwargs(self) -> Dict[str, Any]: kwargs = {} - if self.is_quantized: - kwargs["quantization_config"] = self.quantization_config - if self.config.torch_dtype is not None: kwargs["torch_dtype"] = 
getattr(torch, self.config.torch_dtype) + if self.is_quantized: + kwargs["quantization_config"] = self.quantization_config + if self.config.attn_implementation is not None: kwargs["attn_implementation"] = self.config.attn_implementation @@ -339,18 +348,15 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: @torch.inference_mode() def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - with torch.autocast(device_type=self.config.device, dtype=self.amp_dtype, enabled=self.config.amp_autocast): - return self.pretrained_model.forward(**inputs, **kwargs) + return self.pretrained_model.forward(**inputs, **kwargs) @torch.inference_mode() def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - with torch.autocast(device_type=self.config.device, dtype=self.amp_dtype, enabled=self.config.amp_autocast): - return self.pretrained_model.generate(**inputs, **kwargs) + return self.pretrained_model.generate(**inputs, **kwargs) @torch.inference_mode() def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - with torch.autocast(device_type=self.config.device, dtype=self.amp_dtype, enabled=self.config.amp_autocast): - return self.pretrained_model.generate(**inputs, **kwargs) + return self.pretrained_model.generate(**inputs, **kwargs) @torch.inference_mode() def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: @@ -383,10 +389,15 @@ def seed(self): torch.cuda.manual_seed_all(self.config.seed) def clean(self) -> None: - super().clean() - if hasattr(self, "tmpdir"): LOGGER.info("\t+ Cleaning backend temporary directory") self.tmpdir.cleanup() - gc.collect() + if hasattr(self, "pretrained_model"): + LOGGER.info("\t+ Deleting pretrained model") + del self.pretrained_model + gc.collect() + + if self.config.device == "cuda": + LOGGER.info("\t+ Emptying CUDA cache") + torch.cuda.empty_cache() diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py index a4202793..6d02813b 100644 --- a/optimum_benchmark/backends/pytorch/config.py +++ b/optimum_benchmark/backends/pytorch/config.py @@ -23,10 +23,6 @@ class PyTorchConfig(BackendConfig): device_map: Optional[str] = None torch_dtype: Optional[str] = None - # automatic mixed precision options - amp_autocast: bool = False - amp_dtype: Optional[str] = None - # optimization options eval_mode: bool = True to_bettertransformer: bool = False @@ -34,7 +30,11 @@ class PyTorchConfig(BackendConfig): attn_implementation: Optional[str] = None cache_implementation: Optional[str] = None - # compilation options + # automatic mixed precision options + autocast_enabled: bool = False + autocast_dtype: Optional[str] = None + + # torch compile options torch_compile: bool = False torch_compile_config: Dict[str, Any] = field(default_factory=dict) @@ -59,13 +59,14 @@ def __post_init__(self): if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES: raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.") - if self.amp_dtype is not None and self.amp_dtype not in AMP_DTYPES: - raise ValueError(f"`amp_dtype` must be one of {AMP_DTYPES}. Got {self.amp_dtype} instead.") + if self.autocast_dtype is not None and self.autocast_dtype not in AMP_DTYPES: + raise ValueError(f"`autocast_dtype` must be one of {AMP_DTYPES}. 
Got {self.autocast_dtype} instead.") if self.quantization_scheme is not None: if self.quantization_scheme not in QUANTIZATION_CONFIGS: raise ValueError( - f"`quantization_scheme` must be one of {list(QUANTIZATION_CONFIGS.keys())}. Got {self.quantization_scheme} instead." + f"`quantization_scheme` must be one of {list(QUANTIZATION_CONFIGS.keys())}. " + f"Got {self.quantization_scheme} instead." ) if self.quantization_scheme == "bnb" and is_rocm_system(): diff --git a/optimum_benchmark/backends/torch_ort/backend.py b/optimum_benchmark/backends/torch_ort/backend.py index 664f4a32..96c9b286 100644 --- a/optimum_benchmark/backends/torch_ort/backend.py +++ b/optimum_benchmark/backends/torch_ort/backend.py @@ -1,4 +1,3 @@ -import gc import os from logging import getLogger from tempfile import TemporaryDirectory @@ -110,10 +109,12 @@ def train( LOGGER.info("\t+ Finished training") def clean(self) -> None: - super().clean() + if self.config.device == "cuda" and torch.cuda.is_available(): + LOGGER.info("\t+ Emptying CUDA cache") + torch.cuda.empty_cache() if hasattr(self, "tmpdir"): LOGGER.info("\t+ Cleaning backend temporary directory") self.tmpdir.cleanup() - gc.collect() + super().clean() diff --git a/optimum_benchmark/benchmarks/report.py b/optimum_benchmark/benchmarks/report.py index a1a79a27..fd88616c 100644 --- a/optimum_benchmark/benchmarks/report.py +++ b/optimum_benchmark/benchmarks/report.py @@ -45,14 +45,10 @@ def aggregate(measurements: List["BenchmarkMeasurements"]) -> "BenchmarkMeasurem @dataclass class BenchmarkReport(PushToHubMixin): - @classmethod - def from_targets(cls, targets: List[str]) -> "BenchmarkReport": - return cls.from_dict({target: BenchmarkMeasurements() for target in targets}) - @classmethod def from_dict(cls, data: Dict[str, Any]) -> "PushToHubMixin": return make_dataclass( - cls_name="report", fields=[(target, BenchmarkMeasurements) for target in data.keys()], bases=(cls,) + cls_name="BenchmarkReport", fields=[(target, BenchmarkMeasurements) for target in data.keys()], bases=(cls,) )(**data) def log_memory(self): @@ -106,7 +102,7 @@ def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport": measurements = [getattr(report, target) for report in reports] aggregated_measurements[target] = BenchmarkMeasurements.aggregate(measurements) - return cls(**aggregated_measurements) + return cls.from_dict(aggregated_measurements) @classproperty def default_filename(self) -> str: diff --git a/optimum_benchmark/experiment.py b/optimum_benchmark/experiment.py index 36d35afe..97807f5d 100644 --- a/optimum_benchmark/experiment.py +++ b/optimum_benchmark/experiment.py @@ -45,22 +45,25 @@ def default_filename(cls) -> str: return "experiment_config.json" -def run(benchmark_config: BenchmarkConfig, backend_config: BackendConfig) -> BenchmarkReport: +def run(experiment_config: ExperimentConfig) -> BenchmarkReport: """ Runs a benchmark using specified backend and benchmark configurations """ # Allocate requested backend + backend_config: BackendConfig = experiment_config.backend backend_factory: Type[Backend] = get_class(backend_config._target_) backend: Backend = backend_factory(backend_config) # Allocate requested benchmark + benchmark_config: BenchmarkConfig = experiment_config.benchmark benchmark_factory: Type[Benchmark] = get_class(benchmark_config._target_) benchmark: Benchmark = benchmark_factory(benchmark_config) # Benchmark the backend benchmark.run(backend) report = benchmark.get_report() + backend.clean() return report @@ -70,35 +73,33 @@ def 
launch(experiment_config: ExperimentConfig) -> BenchmarkReport: Runs an experiment using specified launcher configuration/logic """ - # We keep track of the main benchmark process PID to be able to - # track its memory usage in isolated and distributed setups - os.environ["BENCHMARK_PID"] = str(os.getpid()) - if os.environ.get("BENCHMARK_INTERFACE", "API") == "API": # We launch the experiment in a temporary directory to avoid # polluting the current working directory with temporary files LOGGER.info("Launching experiment in a temporary directory.") - tmpdir = TemporaryDirectory() original_dir = os.getcwd() + tmpdir = TemporaryDirectory() os.chdir(tmpdir.name) try: # Allocate requested launcher launcher_config: LauncherConfig = experiment_config.launcher launcher_factory: Type[Launcher] = get_class(launcher_config._target_) - launcher: Launcher = launcher_factory(experiment_config.launcher) - report = launcher.launch(run, experiment_config.benchmark, experiment_config.backend) - error = None - except Exception as e: - LOGGER.error("Error during experiment") - error = e + launcher: Launcher = launcher_factory(launcher_config) + # Launch the experiment + report = launcher.launch(run, experiment_config) + except Exception as error: + LOGGER.error("Error during experiment", exc_info=True) + exception = error + else: + exception = None if os.environ.get("BENCHMARK_INTERFACE", "API") == "API": LOGGER.info("Cleaning up experiment temporary directory.") os.chdir(original_dir) tmpdir.cleanup() - if error is not None: - raise error + if exception is not None: + raise exception return report diff --git a/optimum_benchmark/launchers/config.py b/optimum_benchmark/launchers/config.py index 74c6ac94..e72eceb4 100644 --- a/optimum_benchmark/launchers/config.py +++ b/optimum_benchmark/launchers/config.py @@ -27,14 +27,15 @@ def __post_init__(self): if self.device_isolation and self.device_isolation_action is None: LOGGER.warning( "Device isolation is enabled but no action is specified. " - "Please set `device_isolation_action` to either 'error' or 'warn' " - "to specify the action. Defaulting to 'warn'." + "Please set `device_isolation_action` to either `error`, `warn`, or `kill`. " + "Defaulting to `warn`." ) self.device_isolation_action = "warn" - elif self.device_isolation and self.device_isolation_action not in {"error", "warn"}: + + elif self.device_isolation and self.device_isolation_action not in {"error", "warn", "kill"}: raise ValueError( f"Unsupported device isolation action {self.device_isolation_action}. " - "Please set `device_isolation_action` to either 'error' or 'warn'." + "Please set `device_isolation_action` to either `error`, `warn`, or `kill`." 
) diff --git a/optimum_benchmark/launchers/device_isolation_utils.py b/optimum_benchmark/launchers/device_isolation_utils.py new file mode 100644 index 00000000..82a32f83 --- /dev/null +++ b/optimum_benchmark/launchers/device_isolation_utils.py @@ -0,0 +1,214 @@ +import multiprocessing as mp +import os +import signal +import time +from contextlib import contextmanager +from logging import getLogger +from typing import Optional, Set + +from ..import_utils import is_amdsmi_available, is_psutil_available, is_pynvml_available +from ..logging_utils import setup_logging +from ..system_utils import is_nvidia_system, is_rocm_system + +if is_psutil_available(): + import psutil + +if is_pynvml_available(): + import pynvml + +if is_amdsmi_available(): + import amdsmi # type: ignore + + +LOGGER = getLogger("device-isolation") + + +class DeviceIsolationError(Exception): + pass + + +def isolation_error_signal_handler(signum, frame): + raise DeviceIsolationError("Received an error signal from the device isolation process") + + +signal.signal(signal.SIGUSR1, isolation_error_signal_handler) + + +def get_nvidia_devices_pids(device_ids: str) -> Set[int]: + if not is_pynvml_available(): + raise ValueError( + "The library pynvml is required to get the pids running on NVIDIA GPUs, but is not installed. " + "Please install the official and NVIDIA maintained PyNVML library through `pip install nvidia-ml-py`." + ) + + pynvml.nvmlInit() + + devices_pids = set() + devices_ids = list(map(int, device_ids.split(","))) + + for device_id in devices_ids: + device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + device_processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle) + for device_process in device_processes: + devices_pids.add(device_process.pid) + + pynvml.nvmlShutdown() + + return devices_pids + + +def get_amd_devices_pids(device_ids: str) -> Set[int]: + if not is_amdsmi_available(): + raise ValueError( + "The library amdsmi is required to get the pids running on AMD GPUs, but is not installed. " + "Please install the official and AMD maintained amdsmi library from https://github.com/ROCm/amdsmi." 
+ ) + + amdsmi.amdsmi_init() + + devices_pids = set() + devices_ids = list(map(int, device_ids.split(","))) + + processor_handles = amdsmi.amdsmi_get_processor_handles() + for device_id in devices_ids: + processor_handle = processor_handles[device_id] + try: + # these functions fail a lot for no apparent reason + processes_handles = amdsmi.amdsmi_get_gpu_process_list(processor_handle) + except Exception: + continue + + for process_handle in processes_handles: + try: + # these functions fail a lot for no apparent reason + info = amdsmi.amdsmi_get_gpu_process_info(processor_handle, process_handle) + except Exception: + continue + + if info["memory_usage"]["vram_mem"] == 4096: + # not sure why these processes are always present + continue + + devices_pids.add(info["pid"]) + + amdsmi.amdsmi_shut_down() + + return devices_pids + + +def get_pids_running_on_system_devices(device_ids: str) -> Set[int]: + """Returns the set of pids running on the system device(s).""" + if is_nvidia_system(): + devices_pids = get_nvidia_devices_pids(device_ids) + elif is_rocm_system(): + devices_pids = get_amd_devices_pids(device_ids) + else: + raise ValueError("get_pids_running_on_system_device is only supported on NVIDIA and AMD GPUs") + + return devices_pids + + +def get_children_pids(pid: int) -> Set[int]: + """Returns the set of pids of the children of the given process.""" + if not is_psutil_available(): + raise ValueError( + "The library psutil is required to get the children pids of a process, but is not installed. " + "Please install the official and cross-platform psutil library through `pip install psutil`." + ) + + if not psutil.pid_exists(pid): + LOGGER.warn(f"Process with pid [{pid}] does not exist.") + return set() + + process = psutil.Process(pid) + children = process.children(recursive=True) + children_pids = {child.pid for child in children} + + return children_pids + + +def assert_device_isolation(action: str, pid: int, device_ids: str): + setup_logging("INFO", prefix="DEVICE-ISOLATION-PROCESS") + + assert action in ["warn", "error", "kill"], f"Unsupported action `{action}`" + + while psutil.pid_exists(pid): + device_pids = get_pids_running_on_system_devices(device_ids=device_ids) + device_pids = {p for p in device_pids if psutil.pid_exists(p)} + + permitted_pids = {pid} | get_children_pids(pid) + permitted_pids = {p for p in permitted_pids if psutil.pid_exists(p)} + + foreign_pids = device_pids - permitted_pids + + if len(foreign_pids) > 0: + LOGGER.warn( + f"Found foreign process(es) [{foreign_pids}] running on the isolated device(s) [{device_ids}], " + f"other than the isolated process [{pid}] (and its children)." 
+ ) + + if action == "warn": + LOGGER.warn("Make sure no other process is running on the isolated device(s) while benchmarking.") + elif action == "error": + LOGGER.error("Signaling the isolated process to error out...") + os.kill(pid, signal.SIGUSR1) + elif action == "kill": + LOGGER.error("Killing the isolated process...") + os.kill(pid, signal.SIGKILL) + + LOGGER.warn("Exiting the isolation process...") + exit(0) + + time.sleep(1) + + +@contextmanager +def device_isolation_context(enable: bool, action: str, pid: int, device_ids: Optional[str] = None): + if not enable: + yield + return + + if action is None: + raise ValueError("Device isolation requires the action to be specified") + elif action not in ["warn", "error", "kill"]: + raise ValueError(f"Unsupported action `{action}`") + + if pid is None: + raise ValueError("Device isolation requires the pid of the isolated process") + + if device_ids is None: + if is_nvidia_system(): + device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", None) + elif is_rocm_system(): + device_ids = os.environ.get("ROCR_VISIBLE_DEVICES", None) + + if device_ids is None: + raise ValueError( + "Device isolation requires the device_ids of the isolated device(s) to be specified. " + "Or for the environment variable `CUDA_VISIBLE_DEVICES` or `ROCR_VISIBLE_DEVICES` to be set." + ) + + if not (is_nvidia_system() or is_rocm_system()): + raise ValueError("Device isolation is only supported on NVIDIA and AMD GPUs") + + device_isolation_process = mp.Process( + target=assert_device_isolation, kwargs={"action": action, "pid": pid, "device_ids": device_ids}, daemon=True + ) + device_isolation_process.start() + + LOGGER.info( + f"\t+ Started device(s) isolation process [{device_isolation_process.pid}], monitoring " + f"the isolated process [{pid}], running on device(s) [{device_ids}], with action [{action}]." + ) + + yield + + device_isolation_process.terminate() + device_isolation_process.join(timeout=1) + + if device_isolation_process.is_alive(): + LOGGER.warn("The isolation process did not terminate gracefully. Killing it forcefully...") + device_isolation_process.kill() + device_isolation_process.join(timeout=1) + + device_isolation_process.close() diff --git a/optimum_benchmark/launchers/inline/config.py b/optimum_benchmark/launchers/inline/config.py index 1e4ff9c7..a1821f70 100644 --- a/optimum_benchmark/launchers/inline/config.py +++ b/optimum_benchmark/launchers/inline/config.py @@ -13,3 +13,13 @@ class InlineConfig(LauncherConfig): def __post_init__(self): super().__post_init__() + + if self.device_isolation: + raise ValueError( + "Device isolation is not supported with the inline launcher. Use `process` launcher instead." + ) + + if self.device_isolation_action is not None: + raise ValueError( + "Device isolation is not supported with the inline launcher. Use `process` launcher instead." 
+ ) diff --git a/optimum_benchmark/launchers/inline/launcher.py b/optimum_benchmark/launchers/inline/launcher.py index 382d8461..77f5d089 100644 --- a/optimum_benchmark/launchers/inline/launcher.py +++ b/optimum_benchmark/launchers/inline/launcher.py @@ -1,10 +1,8 @@ -import os from logging import getLogger from typing import Callable from ...benchmarks.report import BenchmarkReport from ..base import Launcher -from ..isolation_utils import device_isolation from .config import InlineConfig LOGGER = getLogger("inline") @@ -17,12 +15,10 @@ def __init__(self, config: InlineConfig): super().__init__(config) def launch(self, worker: Callable, *worker_args) -> BenchmarkReport: - with device_isolation( - isolated_pid=os.getpid(), - enabled=self.config.device_isolation, - action=self.config.device_isolation_action, - ): - LOGGER.info("\t+ Launching benchmark in the main process.") - report = worker(*worker_args) + LOGGER.warn( + "\t+ Running benchmark in the main process. " + "This is only recommended for debugging purposes and not for benchmarking." + ) + report = worker(*worker_args) return report diff --git a/optimum_benchmark/launchers/isolation_utils.py b/optimum_benchmark/launchers/isolation_utils.py deleted file mode 100644 index 615a6097..00000000 --- a/optimum_benchmark/launchers/isolation_utils.py +++ /dev/null @@ -1,178 +0,0 @@ -import os -import signal -import time -from contextlib import contextmanager -from logging import getLogger -from multiprocessing import Process -from typing import Set - -from ..import_utils import is_amdsmi_available, is_psutil_available, is_pynvml_available -from ..logging_utils import setup_logging -from ..system_utils import is_nvidia_system, is_rocm_system - -if is_psutil_available(): - import psutil - -if is_pynvml_available(): - import pynvml - -if is_amdsmi_available(): - import amdsmi # type: ignore - - -LOGGER = getLogger("device-isolation") - - -def isolation_error_signal_handler(signum, frame): - LOGGER.error(f"Process {os.getpid()} received an isolation signal with an `error` action. Exiting...") - raise InterruptedError("Your device is not isolated (other processes are running on it). Exiting...") - - -def isolation_warn_signal_handler(signum, frame): - LOGGER.warn(f"Process {os.getpid()} received an isolation signal with a `warn` action. Ignoring...") - pass - - -signal.signal(signal.SIGUSR1, isolation_error_signal_handler) -signal.signal(signal.SIGUSR2, isolation_warn_signal_handler) - - -def get_nvidia_devices_pids(device_ids: str) -> Set[int]: - if not is_pynvml_available(): - raise ValueError( - "The library pynvml is required to get the pids running on NVIDIA GPUs, but is not installed. " - "Please install the official and NVIDIA maintained PyNVML library through `pip install nvidia-ml-py`." - ) - - pynvml.nvmlInit() - - devices_pids = set() - devices_ids = list(map(int, device_ids.split(","))) - - for device_id in devices_ids: - device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - device_processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle) - for device_process in device_processes: - devices_pids.add(device_process.pid) - - pynvml.nvmlShutdown() - - return devices_pids - - -def get_amd_devices_pids(device_ids: str) -> Set[int]: - if not is_amdsmi_available(): - raise ValueError( - "The library amdsmi is required to get the pids running on AMD GPUs, but is not installed. " - "Please install the official and AMD maintained amdsmi library from https://github.com/ROCm/amdsmi." 
- ) - - amdsmi.amdsmi_init() - - devices_pids = set() - devices_ids = list(map(int, device_ids.split(","))) - - processor_handles = amdsmi.amdsmi_get_processor_handles() - for device_id in devices_ids: - processor_handle = processor_handles[device_id] - try: - # these functions fail a lot for no apparent reason - processes_handles = amdsmi.amdsmi_get_gpu_process_list(processor_handle) - except Exception: - continue - - for process_handle in processes_handles: - try: - # these functions fail a lot for no apparent reason - info = amdsmi.amdsmi_get_gpu_process_info(processor_handle, process_handle) - except Exception: - continue - - if info["memory_usage"]["vram_mem"] == 4096: - # not sure why these processes are always present - continue - - devices_pids.add(info["pid"]) - - amdsmi.amdsmi_shut_down() - - return devices_pids - - -def get_pids_running_on_system_devices(device_ids: str) -> Set[int]: - """Returns the set of pids running on the system device(s).""" - if is_nvidia_system(): - devices_pids = get_nvidia_devices_pids(device_ids) - elif is_rocm_system(): - devices_pids = get_amd_devices_pids(device_ids) - else: - raise ValueError("get_pids_running_on_system_device is only supported on NVIDIA and AMD GPUs") - - return devices_pids - - -def assert_system_devices_isolation(isolated_pid: int, device_ids: str, action: str): - setup_logging("WARNING") - - if action == "error": - action_signal = signal.SIGUSR1 - elif action == "warn": - action_signal = signal.SIGUSR2 - else: - raise ValueError(f"Unsupported action {action}") - - while psutil.pid_exists(isolated_pid): - devices_pids = get_pids_running_on_system_devices(device_ids=device_ids) - devices_pids = {pid for pid in devices_pids if psutil.pid_exists(pid)} - permitted_pids = {isolated_pid} | {child.pid for child in psutil.Process(isolated_pid).children(recursive=True)} - non_permitted_pids = devices_pids - permitted_pids - - if len(non_permitted_pids) > 0: - LOGGER.warn(f"Found non-permitted process(es) running on system device(s): {non_permitted_pids}") - LOGGER.warn(f"Sending an action signal `{action}` to the isolated process {isolated_pid}...") - os.kill(isolated_pid, action_signal) - LOGGER.warn("Exiting...") - exit(0) - - time.sleep(1) - - -@contextmanager -def device_isolation(isolated_pid: int, enabled: bool, action: str): - if not enabled: - yield - return - - if is_nvidia_system(): - device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", None) - elif is_rocm_system(): - device_ids = os.environ.get("ROCR_VISIBLE_DEVICES", None) - else: - raise ValueError("Device isolation is only supported on NVIDIA and AMD GPUs") - - if device_ids is None: - raise ValueError( - "Device isolation requires CUDA_VISIBLE_DEVICES or ROCR_VISIBLE_DEVICES to be set but none were found." 
- ) - - isolation_process = Process( - target=assert_system_devices_isolation, - kwargs={ - "isolated_pid": isolated_pid, - "device_ids": device_ids, - "action": action, - }, - daemon=True, - ) - isolation_process.start() - - LOGGER.info(f"\t+ Launched device(s) isolation process {isolation_process.pid}") - LOGGER.info(f"\t+ Isolating device(s) [{device_ids}]") - - yield - - if isolation_process.is_alive(): - LOGGER.info("\t+ Closing device(s) isolation process...") - isolation_process.kill() - isolation_process.join() - isolation_process.close() diff --git a/optimum_benchmark/launchers/process/launcher.py b/optimum_benchmark/launchers/process/launcher.py index 246c701a..7fdbb9d9 100644 --- a/optimum_benchmark/launchers/process/launcher.py +++ b/optimum_benchmark/launchers/process/launcher.py @@ -1,13 +1,12 @@ +import multiprocessing as mp import os from logging import getLogger from typing import Callable -import torch.multiprocessing as mp - from ...benchmarks.report import BenchmarkReport from ...logging_utils import setup_logging from ..base import Launcher -from ..isolation_utils import device_isolation +from ..device_isolation_utils import device_isolation_context from .config import ProcessConfig LOGGER = getLogger("process") @@ -24,41 +23,33 @@ def __init__(self, config: ProcessConfig): mp.set_start_method(self.config.start_method, force=True) def launch(self, worker: Callable, *worker_args) -> BenchmarkReport: - log_level = getLogger().getEffectiveLevel() - ctx = mp.get_context(self.config.start_method) + log_level = ctx.get_logger().getEffectiveLevel() queue = ctx.Queue() lock = ctx.Lock() - with device_isolation( - isolated_pid=os.getpid(), - enabled=self.config.device_isolation, - action=self.config.device_isolation_action, + isolated_process = mp.Process(target=target, args=(worker, queue, lock, log_level, *worker_args), daemon=False) + isolated_process.start() + + with device_isolation_context( + enable=self.config.device_isolation, action=self.config.device_isolation_action, pid=isolated_process.pid ): - process_context = mp.start_processes( - entrypoint, - args=(worker, queue, lock, log_level, *worker_args), - start_method=self.config.start_method, - daemon=False, - join=False, - nprocs=1, - ) - LOGGER.info(f"\t+ Launched benchmark in isolated process {process_context.pids()[0]}.") - while not process_context.join(): - pass + isolated_process.join() + + if isolated_process.exitcode != 0: + raise RuntimeError(f"Process exited with non-zero code {isolated_process.exitcode}") + elif queue.empty(): + raise RuntimeError("No report was returned by the isolated process.") report: BenchmarkReport = queue.get() return report -def entrypoint(i, worker, queue, lock, log_level, *worker_args): - """ - This a pickalable function that correctly sets up the logging configuration for the worker process, - and puts the output of the worker function into a lock-protected queue. 
- """ - - setup_logging(log_level, prefix=f"PROC-{i}") +def target(worker, queue, lock, log_level, *worker_args): + os.environ["ISOLATED_PROCESS_PID"] = str(os.getpid()) + setup_logging(level=log_level, prefix="ISOLATED-PROCESS") + LOGGER.info(f"Running benchmark in isolated process [{os.getpid()}].") worker_output = worker(*worker_args) diff --git a/optimum_benchmark/launchers/torchrun/config.py b/optimum_benchmark/launchers/torchrun/config.py index ff816315..59e9aa75 100644 --- a/optimum_benchmark/launchers/torchrun/config.py +++ b/optimum_benchmark/launchers/torchrun/config.py @@ -21,7 +21,7 @@ class TorchrunConfig(LauncherConfig): # On each node the elastic agent will launch this amount of workers that will execute user defined function. nproc_per_node: int = 2 # User defined role of the worker (defaults to "trainer"). - role: str = "benchmark_worker" + role: str = "benchmarker" # The interval in seconds that is used by the elastic_agent as a period of monitoring workers. monitor_interval: int = 30 # The name of the rdzv store. @@ -31,20 +31,13 @@ class TorchrunConfig(LauncherConfig): # The endpoint of the rdzv sync. storage. rdzv_endpoint: str = "localhost:0" # Key, value pair that specifies rendezvous specific configuration. - rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"rank": 0, "timeout": 900}) + rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"rank": 0, "timeout": -1}) + # The timeout in seconds that is used by the elastic agent to wait for the workers to enter the rendezvous. + rdzv_timeout: int = -1 # The maximum amount of restarts that elastic agent will conduct on workers before failure. max_restarts: int = 0 # The method is used by the elastic agent to start the workers (spawn, fork, forkserver). start_method: str = "spawn" - # base log directory where log files are written. If not set, one is created in a tmp dir but NOT removed on exit. - log_dir: Optional[str] = None - # configuration to redirect stdout/stderr to log files. - # Pass a single Std enum to redirect all workers, or a mapping keyed by local_rank to selectively redirect. - redirects: str = "0" # Std.NONE - # configuration to "tee" stdout/stderr to console + log file. - tee: str = "0" # Std.NONE - # configuration to initialize metrics. - metrics_cfg: Dict[str, str] = field(default_factory=lambda: {}) # address of the local node if any. If not set, a lookup on the local machine's FQDN will be performed. 
local_addr: Optional[str] = None diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py index 394dc78f..2dc60dbc 100644 --- a/optimum_benchmark/launchers/torchrun/launcher.py +++ b/optimum_benchmark/launchers/torchrun/launcher.py @@ -1,17 +1,16 @@ +import multiprocessing as mp import os from logging import getLogger -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict import torch.distributed -import torch.multiprocessing as mp -from torch.distributed.elastic.multiprocessing import Std from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.launcher.api import LaunchConfig, launch_agent +from torch.distributed.launcher.api import LaunchConfig, elastic_launch from ...benchmarks.report import BenchmarkReport from ...logging_utils import setup_logging from ..base import Launcher -from ..isolation_utils import device_isolation +from ..device_isolation_utils import device_isolation_context from .config import TorchrunConfig LOGGER = getLogger("torchrun") @@ -27,78 +26,92 @@ def __init__(self, config: TorchrunConfig): LOGGER.info(f"\t+ Setting multiprocessing start method to {self.config.start_method}.") mp.set_start_method(self.config.start_method, force=True) - def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]: - log_level = getLogger().getEffectiveLevel() - launch_config = LaunchConfig( + self.launch_config = LaunchConfig( min_nodes=self.config.min_nodes, max_nodes=self.config.max_nodes, nproc_per_node=self.config.nproc_per_node, - role=self.config.role, - monitor_interval=self.config.monitor_interval, run_id=self.config.rdzv_id, + role=self.config.role, rdzv_endpoint=self.config.rdzv_endpoint, rdzv_backend=self.config.rdzv_backend, rdzv_configs=self.config.rdzv_configs, + rdzv_timeout=self.config.rdzv_timeout, max_restarts=self.config.max_restarts, + monitor_interval=self.config.monitor_interval, start_method=self.config.start_method, - metrics_cfg=self.config.metrics_cfg, - redirects=Std.from_str(self.config.redirects), - tee=Std.from_str(self.config.tee), local_addr=self.config.local_addr, - log_dir=self.config.log_dir, ) + def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]: ctx = mp.get_context(self.config.start_method) + log_level = ctx.get_logger().getEffectiveLevel() queue = ctx.Queue() lock = ctx.Lock() - with device_isolation( - isolated_pid=os.getpid(), - enabled=self.config.device_isolation, - action=self.config.device_isolation_action, + isolated_process = mp.Process( + target=target, + args=(worker, queue, lock, log_level, *worker_args), + kwargs={"launch_config": self.launch_config}, + daemon=False, + ) + isolated_process.start() + + with device_isolation_context( + enable=self.config.device_isolation, action=self.config.device_isolation_action, pid=isolated_process.pid ): - LOGGER.info(f"\t+ Launching torchrun agent with {self.config.nproc_per_node} worker processes") - launch_agent( - entrypoint=entrypoint, args=(worker, queue, lock, log_level, *worker_args), config=launch_config - ) + isolated_process.join() - reports: List[BenchmarkReport] = [] + if isolated_process.exitcode != 0: + raise RuntimeError(f"Process exited with non-zero code {isolated_process.exitcode}.") + elif queue.empty(): + raise RuntimeError("No report was returned by the isolated process.") + reports = [] while not queue.empty(): reports.append(queue.get()) - if len(reports) > 1: - LOGGER.info(f"\t+ Merging benchmark reports from {len(reports)} 
workers") - report = reports[0].aggregate(reports) - elif len(reports) == 1: - report = reports[0] - else: - raise ValueError("No benchmark report was returned by the workers") + if len(reports) != self.config.nproc_per_node: + raise RuntimeError( + f"Number of gathered reports ({len(reports)}) does not match the number of processes ({self.config.nproc_per_node})." + ) - # Log the final report + report = BenchmarkReport.aggregate(reports) report.log() return report +def target(worker, queue, lock, log_level, *worker_args, launch_config: LaunchConfig): + os.environ["ISOLATED_PROCESS_PID"] = str(os.getpid()) + setup_logging(level=log_level, prefix="ISOLATED-PROCESS") + LOGGER.info(f"Running benchmark in isolated process [{os.getpid()}].") + + elastic_agent_launcher = elastic_launch(config=launch_config, entrypoint=entrypoint) + elastic_agent_launcher(worker, queue, lock, log_level, *worker_args) + + @record def entrypoint(worker, queue, lock, log_level, *worker_args): - """ - This a pickalable function that correctly sets up the logging configuration - """ + torch.distributed.init_process_group() + rank = torch.distributed.get_rank() - rank = int(os.environ["RANK"]) - torch.cuda.set_device(rank) if torch.cuda.is_available() else None - setup_logging(level=log_level, prefix=f"RANK-{rank}") if rank == 0 else setup_logging(level="ERROR") + if rank == "0": + setup_logging(level=log_level, prefix=f"RANK-{rank}") + elif os.environ.get("LOG_ALL_RANKS", None) == "1": + setup_logging(level=log_level, prefix=f"RANK-{rank}") + else: + setup_logging(level="ERROR", prefix=f"RANK-{rank}") - torch.distributed.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo") - torch.distributed.barrier() + if torch.cuda.is_available(): + LOGGER.info("\t+ Setting torch.distributed cuda device") + torch.cuda.set_device(rank % torch.cuda.device_count()) + torch.distributed.barrier() output = worker(*worker_args) - torch.distributed.barrier() - torch.distributed.destroy_process_group() lock.acquire() queue.put(output) lock.release() + + torch.distributed.destroy_process_group() diff --git a/optimum_benchmark/trackers/energy.py b/optimum_benchmark/trackers/energy.py index 4b179606..ffd6e887 100644 --- a/optimum_benchmark/trackers/energy.py +++ b/optimum_benchmark/trackers/energy.py @@ -114,8 +114,8 @@ def __init__(self, backend: str, device: str, device_ids: Optional[str] = None): self.device = device self.backend = backend self.device_ids = device_ids - self.asynchronous = backend == "pytorch" and device == "cuda" - self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized() + self.is_asynchronous = backend == "pytorch" and device == "cuda" + self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized() if self.device == "cuda": if self.device_ids is None: @@ -164,28 +164,28 @@ def __init__(self, backend: str, device: str, device_ids: Optional[str] = None): @contextmanager def track(self, file_prefix: str = "task"): - if self.asynchronous: - torch.cuda.synchronize() - - if self.distributed: + if self.is_distributed: torch.distributed.barrier() + if self.is_asynchronous: + torch.cuda.synchronize() + self.emission_tracker.start_task() yield + if self.is_distributed: + torch.distributed.barrier() + + if self.is_asynchronous: + torch.cuda.synchronize() + emission_data: EmissionsData = self.emission_tracker.stop_task() with open(f"{file_prefix}_codecarbon.json", "w") as f: LOGGER.info(f"\t+ Saving codecarbon emission data to 
{file_prefix}_codecarbon.json") dump(asdict(emission_data), f, indent=4) - if self.distributed: - torch.distributed.barrier() - - if self.asynchronous: - torch.cuda.synchronize() - self.cpu_energy = emission_data.cpu_energy self.gpu_energy = emission_data.gpu_energy self.ram_energy = emission_data.ram_energy diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 3f0a3c62..340fcc61 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -123,10 +123,10 @@ class LatencyTracker: def __init__(self, device: str, backend: str): self.device = device self.backend = backend - self.asynchronous = self.backend == "pytorch" and self.device == "cuda" - self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized() + self.is_asynchronous = self.backend == "pytorch" and self.device == "cuda" + self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized() - if self.asynchronous: + if self.is_asynchronous: LOGGER.info("\t+ Tracking latency using Pytorch CUDA events") else: LOGGER.info("\t+ Tracking latency using CPU performance counter") @@ -142,15 +142,15 @@ def reset(self): @contextmanager def track(self): - if self.distributed: + if self.is_distributed: torch.distributed.barrier() - if self.asynchronous: + if self.is_asynchronous: yield from self._pytorch_cuda_latency() else: yield from self._cpu_latency() - if self.distributed: + if self.is_distributed: torch.distributed.barrier() def _pytorch_cuda_latency(self): @@ -170,8 +170,9 @@ def _cpu_latency(self): self.end_events.append(time.perf_counter()) def get_latency(self) -> Latency: - if self.backend == "pytorch" and self.device == "cuda": - torch.cuda.synchronize() # synchronize the device to make sure all events have been recorded + if self.is_asynchronous: + torch.cuda.synchronize() + latencies_list = [ self.start_events[i].elapsed_time(self.end_events[i]) / 1e3 for i in range(len(self.start_events)) ] @@ -204,7 +205,13 @@ class StepLatencyTrainerCallback(TrainerCallback): def __init__(self, device: str, backend: str) -> None: self.device = device self.backend = backend - self.asynchronous = self.backend == "pytorch" and self.device == "cuda" + self.is_asynchronous = self.backend == "pytorch" and self.device == "cuda" + self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized() + + if self.is_asynchronous: + LOGGER.info("\t+ Tracking latency using Pytorch CUDA events") + else: + LOGGER.info("\t+ Tracking latency using CPU performance counter") self.start_events: List[Union[float, torch.cuda.Event]] = [] self.end_events: List[Union[float, torch.cuda.Event]] = [] @@ -214,22 +221,23 @@ def reset(self): self.end_events = [] def on_step_begin(self, *args, **kwargs): - if self.asynchronous: + if self.is_asynchronous: self.start_events.append(torch.cuda.Event(enable_timing=True)) self.start_events[-1].record() else: self.start_events.append(time.perf_counter()) def on_step_end(self, *args, **kwargs): - if self.asynchronous: + if self.is_asynchronous: self.end_events.append(torch.cuda.Event(enable_timing=True)) self.end_events[-1].record() else: self.end_events.append(time.perf_counter()) def get_latency(self) -> Latency: - if self.asynchronous: - torch.cuda.synchronize() # synchronize the device to make sure all events have been recorded + if self.is_asynchronous: + torch.cuda.synchronize() + latencies_list = [ self.start_events[i].elapsed_time(self.end_events[i]) / 1e3 for i in 
range(len(self.start_events)) ] @@ -245,8 +253,13 @@ class PerTokenLatencyLogitsProcessor(LogitsProcessor): def __init__(self, device: str, backend: str): self.device = device self.backend = backend - self.asynchronous = self.backend == "pytorch" and self.device == "cuda" - self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized() + self.is_asynchronous = self.backend == "pytorch" and self.device == "cuda" + self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized() + + if self.is_asynchronous: + LOGGER.info("\t+ Tracking latency using Pytorch CUDA events") + else: + LOGGER.info("\t+ Tracking latency using CPU performance counter") self.start_time: Optional[float] = None self.next_is_prefill_end_decode_start: Optional[bool] = None @@ -269,10 +282,10 @@ def reset(self): @contextmanager def track(self): - if self.distributed: + if self.is_distributed: torch.distributed.barrier() - if self.asynchronous: + if self.is_asynchronous: self.prefill_start_events.append(torch.cuda.Event(enable_timing=True)) self.prefill_start_events[-1].record() else: @@ -284,13 +297,13 @@ def track(self): self.next_is_prefill_end_decode_start = None - if self.asynchronous: + if self.is_asynchronous: self.decode_end_events.append(torch.cuda.Event(enable_timing=True)) self.decode_end_events[-1].record() else: self.decode_end_events.append(time.perf_counter()) - if self.distributed: + if self.is_distributed: torch.distributed.barrier() def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): @@ -298,7 +311,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): self.next_is_prefill_end_decode_start is not None ), "PerTokenLatencyLogitsProcessor should only be called inside of track() context" - if self.asynchronous: + if self.is_asynchronous: event = torch.cuda.Event(enable_timing=True) event.record() else: @@ -314,8 +327,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): return scores def get_prefill_latency(self) -> Latency: - if self.asynchronous: - torch.cuda.synchronize() # synchronize the device to make sure all events have been recorded + if self.is_asynchronous: + torch.cuda.synchronize() + latencies_list = [ self.prefill_start_events[i].elapsed_time(self.prefill_end_events[i]) / 1e3 for i in range(len(self.prefill_start_events)) @@ -331,8 +345,9 @@ def get_prefill_latency(self) -> Latency: return Latency.from_values(latencies_list, unit=LATENCY_UNIT) def get_decode_latency(self) -> Latency: - if self.asynchronous: - torch.cuda.synchronize() # synchronize the device to make sure all events have been recorded + if self.is_asynchronous: + torch.cuda.synchronize() + latencies_list = [ self.decode_start_events[i].elapsed_time(self.decode_end_events[i]) / 1e3 for i in range(len(self.decode_start_events)) @@ -347,8 +362,9 @@ def get_decode_latency(self) -> Latency: return Latency.from_values(latencies_list, unit=LATENCY_UNIT) def get_per_token_latency(self) -> Latency: - if self.asynchronous: - torch.cuda.synchronize() # synchronize the device to make sure all events have been recorded + if self.is_asynchronous: + torch.cuda.synchronize() + latencies_list = [ self.per_token_events[i].elapsed_time(self.per_token_events[i + 1]) / 1e3 for i in range(0, len(self.per_token_events) - 1) diff --git a/optimum_benchmark/trackers/memory.py b/optimum_benchmark/trackers/memory.py index dea3ba37..5e7345ae 100644 --- a/optimum_benchmark/trackers/memory.py +++ 
b/optimum_benchmark/trackers/memory.py @@ -95,21 +95,25 @@ def log(self, prefix: str = "forward"): class MemoryTracker: - def __init__(self, device: str, backend: str, device_ids: Optional[str] = None): + def __init__( + self, device: str, backend: str, device_ids: Optional[str] = None, monitored_pid: Optional[int] = None + ): self.device = device self.backend = backend self.device_ids = device_ids - self.monitored_pid = int(os.environ.get("BENCHMARK_PID", os.getpid())) - self.track_cuda_pytorch_memory = self.device == "cuda" and self.backend == "pytorch" - self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized() + self.monitored_pid = monitored_pid + self.uses_cuda_pytorch_allocator = self.device == "cuda" and self.backend == "pytorch" + self.is_distributed = is_torch_distributed_available() and torch.distributed.is_initialized() - LOGGER.info("\t+ Tracking RAM memory") + if self.monitored_pid is None: + self.monitored_pid = int(os.environ.get("ISOLATED_PROCESS_PID", os.getpid())) + LOGGER.info(f"\t+ Tracking RAM memory of process with PID [{self.monitored_pid}]") if self.device == "cuda": self.device_ids = list(map(int, self.device_ids.split(","))) - LOGGER.info(f"\t+ Tracking VRAM memory of CUDA devices {self.device_ids}") + LOGGER.info(f"\t+ Tracking VRAM memory of CUDA devices with IDs [{self.device_ids}]") - if self.track_cuda_pytorch_memory: + if self.uses_cuda_pytorch_allocator: self.num_pytorch_devices = torch.cuda.device_count() if len(self.device_ids) != self.num_pytorch_devices: raise ValueError( @@ -133,24 +137,22 @@ def reset(self): @contextmanager def track(self): - if self.distributed: + if self.is_distributed: torch.distributed.barrier() - if self.track_cuda_pytorch_memory: + if self.uses_cuda_pytorch_allocator: yield from self._cuda_pytorch_memory() elif self.device == "cuda": yield from self._cuda_memory() else: yield from self._cpu_memory() - if self.distributed: + if self.is_distributed: torch.distributed.barrier() def _cuda_pytorch_memory(self): - torch.cuda.empty_cache() - torch.cuda.synchronize() - for device in range(self.num_pytorch_devices): + torch.cuda.synchronize(device=device) try: torch.cuda.reset_peak_memory_stats(device=device) except Exception as e: @@ -158,15 +160,15 @@ def _cuda_pytorch_memory(self): yield from self._cuda_memory() - self.max_allocated_memory = sum( - torch.cuda.max_memory_allocated(device=device) / 1e6 for device in range(self.num_pytorch_devices) - ) - self.max_reserved_memory = sum( - torch.cuda.max_memory_reserved(device=device) / 1e6 for device in range(self.num_pytorch_devices) - ) + self.max_allocated_memory = 0 + self.max_reserved_memory = 0 - torch.cuda.synchronize() - torch.cuda.empty_cache() + for device in range(self.num_pytorch_devices): + try: + self.max_allocated_memory += torch.cuda.max_memory_allocated(device=device) / 1e6 + self.max_reserved_memory += torch.cuda.max_memory_reserved(device=device) / 1e6 + except Exception as e: + LOGGER.warning(f"\t\t+ Could not get max memory stats for device {device}: {e}") def _cuda_memory(self): child_connection, parent_connection = Pipe() diff --git a/tests/configs/_base_.yaml b/tests/configs/_base_.yaml index 84a3f497..de3c7394 100644 --- a/tests/configs/_base_.yaml +++ b/tests/configs/_base_.yaml @@ -5,6 +5,8 @@ defaults: - benchmark: inference # default benchmark - override hydra/hydra_logging: colorlog # colored logging - override hydra/job_logging: colorlog # colored logging + - _self_ + # hydra/cli specific settings hydra: diff --git 
a/tests/configs/cuda_inference_pytorch_text_decoders.yaml b/tests/configs/cuda_inference_pytorch_text_decoders.yaml index 45e8db9e..dfb480fd 100644 --- a/tests/configs/cuda_inference_pytorch_text_decoders.yaml +++ b/tests/configs/cuda_inference_pytorch_text_decoders.yaml @@ -9,4 +9,4 @@ defaults: - _self_ # hydra 1.1 compatibility - override backend: pytorch -experiment_name: cuda_inference_pytorch_text_decoders_no_weights +experiment_name: cuda_inference_pytorch_text_decoders diff --git a/tests/configs/cuda_inference_pytorch_text_encoders.yaml b/tests/configs/cuda_inference_pytorch_text_encoders.yaml index 8fbbacc0..44bc2e7c 100644 --- a/tests/configs/cuda_inference_pytorch_text_encoders.yaml +++ b/tests/configs/cuda_inference_pytorch_text_encoders.yaml @@ -9,4 +9,4 @@ defaults: - _self_ # hydra 1.1 compatibility - override backend: pytorch -experiment_name: cuda_inference_pytorch_text_encoders_no_weights +experiment_name: cuda_inference_pytorch_text_encoders diff --git a/tests/test_api.py b/tests/test_api.py index 52432057..29040360 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,10 +13,7 @@ get_diffusers_pretrained_config, ) from optimum_benchmark.backends.pytorch.config import PyTorchConfig -from optimum_benchmark.backends.timm_utils import ( - extract_timm_shapes_from_config, - get_timm_pretrained_config, -) +from optimum_benchmark.backends.timm_utils import extract_timm_shapes_from_config, get_timm_pretrained_config from optimum_benchmark.backends.transformers_utils import ( extract_transformers_shapes_from_artifacts, get_transformers_pretrained_config, @@ -55,10 +52,7 @@ def test_api_launch(device, benchmark, library, task, model): device_ids = get_gpu_device_ids() if device == "cuda" else None no_weights = False if library != "transformers" else True - launcher_config = ProcessConfig( - device_isolation=device == "cuda", - device_isolation_action="error", - ) + launcher_config = ProcessConfig(device_isolation=device == "cuda", device_isolation_action="error") if benchmark == "training": if library == "transformers": @@ -80,20 +74,12 @@ def test_api_launch(device, benchmark, library, task, model): ) backend_config = PyTorchConfig( - device=device, - device_ids=device_ids, - no_weights=no_weights, - library=library, - model=model, - task=task, + device=device, device_ids=device_ids, no_weights=no_weights, library=library, model=model, task=task ) experiment_name = f"{device}_{benchmark}_{library}_{task}_{model}" experiment_config = ExperimentConfig( - experiment_name=experiment_name, - benchmark=benchmark_config, - launcher=launcher_config, - backend=backend_config, + experiment_name=experiment_name, benchmark=benchmark_config, launcher=launcher_config, backend=backend_config ) benchmark_report = launch(experiment_config) @@ -109,10 +95,7 @@ def test_api_push_to_hub_mixin(): benchmark_config = InferenceConfig(memory=True, latency=True, duration=1, iterations=1, warmup_runs=1) experiment_config = ExperimentConfig( - experiment_name=experiment_name, - benchmark=benchmark_config, - launcher=launcher_config, - backend=backend_config, + experiment_name=experiment_name, benchmark=benchmark_config, launcher=launcher_config, backend=backend_config ) benchmark_report = launch(experiment_config) diff --git a/tests/test_cli.py b/tests/test_cli.py index ea2e8341..1c5d56d1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,6 +7,7 @@ LOGGER = getLogger("test-cli") + TEST_CONFIG_DIR = "/".join(__file__.split("/")[:-1] + ["configs"]) TEST_CONFIG_NAMES = [ config.split(".")[0] 
@@ -25,22 +26,29 @@ def test_cli_configs(config_name): config_name, # to run the tests faster (comment for debugging) "hydra/launcher=joblib", + "hydra.launcher.n_jobs=-1", + "hydra.launcher.batch_size=1", + "hydra.launcher.prefer=threads", ] popen = run_subprocess_and_log_stream_output(LOGGER, args) assert popen.returncode == 0, f"Failed to run {config_name}" -def test_cli_exit_code(): +@pytest.mark.parametrize("launcher", ["inline", "process", "torchrun"]) +def test_cli_exit_code(launcher): args_0 = [ "optimum-benchmark", "--config-dir", TEST_CONFIG_DIR, "--config-name", - "cpu_inference_pytorch_text_encoders", + "_base_", + f"launcher={launcher}", + "experiment_name=test", # compatible task and model "backend.task=text-classification", "backend.model=bert-base-uncased", + "backend.device=cpu", ] popen_0 = run_subprocess_and_log_stream_output(LOGGER, args_0) @@ -51,10 +59,13 @@ def test_cli_exit_code(): "--config-dir", TEST_CONFIG_DIR, "--config-name", - "cpu_inference_pytorch_text_encoders", + "_base_", + f"launcher={launcher}", + "experiment_name=test", # incompatible task and model to trigger error "backend.task=image-classification", "backend.model=bert-base-uncased", + "backend.device=cpu", ] popen_1 = run_subprocess_and_log_stream_output(LOGGER, args_1)
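
Note on the tracker changes above: `LatencyTracker`, `StepLatencyTrainerCallback`, and `PerTokenLatencyLogitsProcessor` now follow one shared measurement pattern — wait on `torch.distributed.barrier()` when a process group is initialized, record paired CUDA events when the PyTorch backend runs on CUDA (kernels are launched asynchronously, so CPU timestamps alone would be misleading), and fall back to `time.perf_counter()` otherwise, synchronizing the device only when latencies are read back. The sketch below is a minimal, self-contained illustration of that pattern; it is not code from this patch, and the helper name `measure_latency` and the dummy workload are invented for the example.

```python
# Minimal sketch (not from this patch) of the CUDA-event vs. perf_counter timing
# pattern used by the trackers; `measure_latency` and the workload are illustrative.
import time
from contextlib import contextmanager
from typing import List

import torch


@contextmanager
def measure_latency(latencies: List[float], is_asynchronous: bool):
    """Append one latency measurement, in seconds, to `latencies`."""
    if is_asynchronous:
        # CUDA kernels run asynchronously, so CPU timestamps would mostly capture
        # launch overhead; events are recorded on the current stream instead.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        yield
        end.record()
        # The trackers defer this synchronize until latencies are aggregated;
        # doing it per measurement keeps the sketch short.
        torch.cuda.synchronize()
        latencies.append(start.elapsed_time(end) / 1e3)  # elapsed_time is in milliseconds
    else:
        start_time = time.perf_counter()
        yield
        latencies.append(time.perf_counter() - start_time)


if __name__ == "__main__":
    latencies: List[float] = []
    use_cuda = torch.cuda.is_available()
    x = torch.randn(1024, 1024, device="cuda" if use_cuda else "cpu")
    for _ in range(10):
        with measure_latency(latencies, is_asynchronous=use_cuda):
            x @ x  # dummy workload
    print(f"mean latency: {sum(latencies) / len(latencies):.6f} s")
```

In a multi-process run the same context would additionally call `torch.distributed.barrier()` before and after the measured region, as the trackers in this patch do, so that all ranks time the same work.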