From ffe0af23575c4f03a07408eacfc50b1a58781429 Mon Sep 17 00:00:00 2001 From: Jinghan Yao Date: Thu, 8 Aug 2024 09:28:32 -0700 Subject: [PATCH 01/13] Fix the bug of deepspeed sequence parallel working with batch size larger than 1 (#5823) Modified the `alltoall` function Verified the results with only `TP`: ![image](https://github.com/user-attachments/assets/9bdd8942-3565-418f-b7be-614293b2f2f6) --------- Co-authored-by: Jinghan Yao Co-authored-by: Sam Ade Jacobs Co-authored-by: Jinghan Yao Co-authored-by: Logan Adams --- deepspeed/sequence/layer.py | 107 +++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 38 deletions(-) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index f17cfa883cc6..e809fe1118b5 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -12,48 +12,76 @@ from deepspeed.accelerator import get_accelerator -def post_all2all(transpose, res_shape): +def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim): def post_func(input): - if transpose: - input = input.transpose(0, 2).contiguous() - input = input.reshape(res_shape) - return input + if batch_dim_idx == 0: + # b, s, n, h + if scatter_idx < 2: + output = input.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head, + head_dim).contiguous() + else: + output = input.permute(1, 0, 2, 3, 4).contiguous() + output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size, + head_dim).contiguous() + else: + # s, b, n, h + if scatter_idx < 2: + output = input.permute(1, 2, 0, 3, 4).contiguous() + output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head, + head_dim).contiguous() + else: + output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous() + return output return post_func -def single_all_to_all(input, scatter_idx, gather_idx, group, async_op=False, handle=None, type=None): +def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None): seq_world_size = dist.get_world_size(group) - inp_shape = list(input.shape) - inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + if batch_dim_idx == 0: + # b, s, n, h + if scatter_idx < 2: + bs, global_seq_len, num_local_head, head_dim = input.shape + input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, + head_dim]).contiguous() + input_t = input_t.permute(1, 0, 2, 3, 4).contiguous() + else: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, + head_dim]).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + else: + # s, b, n, h + if scatter_idx < 2: + global_seq_len, bs, num_local_head, head_dim = input.shape + input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, + head_dim]).contiguous() + else: + local_seq_len, bs, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!" + input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, + head_dim]).contiguous() + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + if scatter_idx < 2: - input_t = input.reshape( - [seq_world_size, inp_shape[scatter_idx]] + \ - inp_shape[scatter_idx + 1:] - ).contiguous() + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head, + head_dim) else: - # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! - input_t = input.reshape( - [-1, seq_world_size, inp_shape[scatter_idx]] + \ - inp_shape[scatter_idx + 1:] - ).transpose(0, 1).contiguous() + post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head, + head_dim) output = torch.empty_like(input_t) work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) - res_shape=( inp_shape[: gather_idx] + \ - [inp_shape[gather_idx] * seq_world_size,] + \ - inp_shape[gather_idx + 1:]) - transpose = True if scatter_idx < 2 else False - post_all2all_fun = post_all2all(transpose, res_shape) - if async_op: if type in ('dq', 'dk'): handle[type + '_work'] = work handle[type + '_grad'] = output handle[type + '_post_all2all_func'] = post_all2all_fun - return output.view(res_shape) + return output res = post_all2all_fun(output) return res @@ -67,6 +95,7 @@ def forward(ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, + batch_dim_idx: int, stream=None, handle=None, type=None, @@ -77,14 +106,15 @@ def forward(ctx: Any, ctx.stream = stream ctx.handle = handle ctx.type = type + ctx.batch_dim_idx = batch_dim_idx if ctx.handle is None: - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) else: # overlap communication path if not is_fwd and type == 'o': assert ctx.stream != None - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) get_accelerator().current_stream().wait_stream(ctx.stream) del ctx.stream.activation_buffer_list # The computation of d o_weight can overlap with the communication of d o_input @@ -92,15 +122,15 @@ def forward(ctx: Any, elif not is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of dq, dk, and dv type = 'd' + type - res = single_all_to_all(input, scatter_idx, gather_idx, group, True, handle, type) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, True, handle, type) elif is_fwd and type in ('q', 'k'): # Achieve communication overlap by pipelining the matrix computation and communication of q, k, and v type = 'fwd_' + type - res = single_all_to_all(input, scatter_idx, gather_idx, group, False, handle, type) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False, handle, type) else: - res = single_all_to_all(input, scatter_idx, gather_idx, group, False) + res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False) return res @@ -108,8 +138,8 @@ def forward(ctx: Any, def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: return (None, - _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.stream, ctx.handle, - ctx.type, False), None, None, None, None, None, None) + _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.batch_dim_idx, + ctx.stream, ctx.handle, ctx.type, False), None, None, None, None, None, None, None) class DistributedAttention(torch.nn.Module): @@ -148,13 +178,14 @@ def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): self.dafult_stream.wait_event(layer.done_event) - def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor: + def forward(self, query: Tensor, key: Tensor, value: Tensor, batch_dim_idx: int, *args: Any, **kwargs) -> Tensor: """ forward Arguments: query (Tensor): query input to the layer key (Tensor): key input to the layer value (Tensor): value input to the layer + batch_dim_idx (int): indicating which dim is batch args: other args Returns: @@ -179,15 +210,15 @@ def pre_hook_fun(grad): return pre_hook_fun self.layer_sync(query) - query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, None, + query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'q') self.layer_sync(key) - key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, None, self.overlap_handles, - 'k') + key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, + self.overlap_handles, 'k') if self.sp_overlap_comm: self.dafult_stream.wait_stream(self.sp_stream) - value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, None, + value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'v') if self.sp_overlap_comm: @@ -205,8 +236,8 @@ def pre_hook_fun(grad): context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, self.sp_stream, - self.overlap_handles, 'o') + output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx, batch_dim_idx, + self.sp_stream, self.overlap_handles, 'o') #out e.g., [s/p::h] return output From 1890b814244825afdb29f070352dae406f86ce73 Mon Sep 17 00:00:00 2001 From: vikram singh shekhawat Date: Wed, 14 Aug 2024 04:36:31 +0530 Subject: [PATCH 02/13] Upgrade HPU image to v1.16.2. (#5610) Upgraded HPU/Gaudi image to v1.16.2. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Logan Adams --- .github/workflows/hpu-gaudi2.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index f81e690e835b..ac19638e67de 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,13 +39,14 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest + image: vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice env: PT_HPU_LAZY_MODE: 0 + TORCHINDUCTOR_COMPILE_THREADS: 1 TEST_LIST: | test_accelerator.py test_autotuning.py @@ -103,7 +104,7 @@ jobs: - name: Check container state run: | ldd --version - hl-smi + hl-smi -L python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" @@ -128,7 +129,7 @@ jobs: unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests export PT_HPU_LAZY_MODE=${PT_HPU_LAZY_MODE} + export TORCHINDUCTOR_COMPILE_THREADS=${TORCHINDUCTOR_COMPILE_THREADS} TEST_LIST=$(echo "$TEST_LIST" | awk 'NF{printf "%s%s", (NR>1 ? " or " : ""), $0} END{if (NR>1) print ""}') echo "TEST_LIST ${TEST_LIST}" - echo "PT_HPU_LAZY_MODE ${PT_HPU_LAZY_MODE}" pytest --verbose unit/ -k "${TEST_LIST}" From 6e5d58d24843864ea72fbf78846123f4c30d935a Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 13 Aug 2024 16:36:22 -0700 Subject: [PATCH 03/13] OptimizedLinear updates (#5791) This is a refresh of of `OptimizedLinear` with the following features to improve performance and usability: * More efficient sharing of base weights using `all_gather_into_tensor` * Flattened sharded weights * Selectively offload frozen weights to cpu * `deepspeed.linear.Init` that allows injecting OptimizedLinear during model construction (similar to zero.Init) * Support for load state dict directly in OptimizedLinear, this allows loading HF model weights correctly into sharded params * Various bug fixes for the LoRA implementation introduced previously * Several new unit tests Builds on-top of @RezaYazdaniAminabadi's previous FP8 updates (#5764) to support dense model fp8 quantization. Example usage of this to fine-tune llama-3.1-405B on a single node: https://github.com/Snowflake-Labs/snowflake-arctic/tree/main/training/llama3.1 --------- Co-authored-by: Reza Yazdani Co-authored-by: Reza Yazdani <152926435+sfc-gh-reyazda@users.noreply.github.com> --- deepspeed/linear/__init__.py | 1 + deepspeed/linear/config.py | 12 ++- deepspeed/linear/context_manager.py | 90 ++++++++++++++++ deepspeed/linear/optimized_linear.py | 144 ++++++++++++++++++------- deepspeed/linear/quantization.py | 12 ++- deepspeed/ops/fp_quantizer/quantize.py | 7 ++ deepspeed/runtime/engine.py | 46 +++++++- tests/unit/linear/test_ctx.py | 106 ++++++++++++++++++ 8 files changed, 379 insertions(+), 39 deletions(-) create mode 100644 deepspeed/linear/context_manager.py create mode 100644 tests/unit/linear/test_ctx.py diff --git a/deepspeed/linear/__init__.py b/deepspeed/linear/__init__.py index a27f1c3eaee7..9931a95a0a40 100644 --- a/deepspeed/linear/__init__.py +++ b/deepspeed/linear/__init__.py @@ -5,3 +5,4 @@ from .optimized_linear import OptimizedLinear from .config import LoRAConfig, QuantizationConfig +from .context_manager import Init, init_lora diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py index ae9050a3c92b..2632ce7de9c4 100644 --- a/deepspeed/linear/config.py +++ b/deepspeed/linear/config.py @@ -3,7 +3,8 @@ # DeepSpeed Team -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import List @dataclass @@ -17,10 +18,19 @@ class LoRAConfig: base_weight_sharding (int): The degree to which the base weights are sharded, should typically be set to the data-parallel world size to maximize the memory reduction benefits. Defaults to 1, which means this feature is disabled. + offload (bool): offload frozen parameters to cpu when not in use + offload_ratio (float): ratio of parameters to offload to cpu when not in use + delay_lora_init (bool): initialize lora parameters at time of model init or allow manual init later + target_mods (str): target module names to apply LoRA to, defaults to llama-3.1 arch """ lora_r: int = 64 lora_alpha: float = 16. base_weight_sharding: int = 1 + offload: bool = False + offload_ratio: float = 0.0 + delay_lora_init: bool = False + target_mods: List[str] = field( + default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']) @dataclass diff --git a/deepspeed/linear/context_manager.py b/deepspeed/linear/context_manager.py new file mode 100644 index 000000000000..204fa0fe9c1d --- /dev/null +++ b/deepspeed/linear/context_manager.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .optimized_linear import LoRAOptimizedLinear, OptimizedLinear + +import torch + +try: + import transformers +except ImportError: + transformers = None + + +def init_lora(model): + model.requires_grad_(False) + for m in model.modules(): + if isinstance(m, LoRAOptimizedLinear): + m.init_lora() + + +class Init(object): + """ + Init context wrapper similar in style to zero.Init. Allows for injecting OptimizedLinear during model + construction which will shard base weights and reduce overall memory usage during model init. Primarily + useful when initializing a model via transformers.AutoModelForCausalLM. + + Example usage: + lora_config = deepspeed.linear.LoRAConfig(..) + quant_config = deepspeed.linear.QuantizationConfig(..) + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-405B") + + """ + + def __init__(self, lora_config=None, quant_config=None): + self._orig_nn_linear = torch.nn.Linear + self._orig_causallm_pretrained = None + if transformers != None: + self._orig_causallm_pretrained = transformers.AutoModelForCausalLM.from_pretrained + self._orig_causallm_config = transformers.AutoModelForCausalLM.from_config + self.lora_config = lora_config + self.quant_config = quant_config + self._post_init_complete = False + + def __enter__(self): + + class OptLinearWrapper: + _orig_nn_linear = self._orig_nn_linear + _lora_config = self.lora_config + _quant_config = self.quant_config + + def __new__(self, *args, **kwargs): + self._lora_config.delay_lora_init = True + kwargs['lora_config'] = self._lora_config + kwargs['quantization_config'] = self._quant_config + kwargs['linear_cls'] = self._orig_nn_linear + return OptimizedLinear(*args, **kwargs) + + def _model_init(model): + if self.lora_config != None: + init_lora(model) + self._post_init_complete = True + return model + + # ensures non-lora params are frozen and lora weights are initialized + def from_pretrained(*args, **kwargs): + model = self._orig_causallm_pretrained(*args, **kwargs) + return _model_init(model) + + def from_config(*args, **kwargs): + model = self._orig_causallm_config(*args, **kwargs) + return _model_init(model) + + torch.nn.Linear = OptLinearWrapper + if transformers != None: + transformers.AutoModelForCausalLM.from_pretrained = from_pretrained + transformers.AutoModelForCausalLM.from_config = from_config + + def __exit__(self, *args, **kwargs): + torch.nn.Linear = self._orig_nn_linear + if not self._post_init_complete: + print('WARNING: For some reason LoRA modules are not initialized, this is usually done automatically ' + 'if using transformers via (AutoModelForCausalLM from_pretrained/from_config). ' + 'You must call `init_lora` on each module in order to use DeepSpeed LoRA, otherwise ' + 'you will error out during runtime.') + else: + transformers.AutoModelForCausalLM.from_pretrained = self._orig_causallm_pretrained + transformers.AutoModelForCausalLM.from_config = self._orig_causallm_config diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py index e982785a8122..3720196aa255 100644 --- a/deepspeed/linear/optimized_linear.py +++ b/deepspeed/linear/optimized_linear.py @@ -40,7 +40,9 @@ def __new__(self, bias: bool = False, lora_config: LoRAConfig = None, quantization_config: QuantizationConfig = None, - dtype=torch.bfloat16): + device=None, + dtype=torch.bfloat16, + linear_cls=nn.Linear): if quantization_config is not None and not is_dataclass(quantization_config): raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}") @@ -48,7 +50,7 @@ def __new__(self, raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}") if lora_config is None and quantization_config is None: # Everything disabled, fall back to normal nn.Linear - self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) + self = linear_cls(input_dim, output_dim, bias=bias, dtype=dtype, device=device) elif lora_config: # lora enabled, quantization may or may not be @@ -57,7 +59,9 @@ def __new__(self, bias=bias, lora_config=lora_config, quantization_config=quantization_config, - dtype=dtype) + dtype=dtype, + device=device, + linear_cls=linear_cls) elif quantization_config: # only quantization enabled, no lora @@ -78,57 +82,121 @@ def __init__(self, lora_config: LoRAConfig = None, quantization_config: QuantizationConfig = None, device=None, - dtype=torch.bfloat16): + dtype=torch.bfloat16, + linear_cls=nn.Linear): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.bias = bias self.lora_config = lora_config self.quantization_config = quantization_config - device = get_accelerator().current_device_name() if device is None else device + self.device = get_accelerator().current_device_name() if device is None else device + self.linear_cls = linear_cls + self.dtype = dtype assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" - + assert not self.bias, "bias=True is not supported by LoRAOptimizedLinear" self.zero_shards = self.lora_config.base_weight_sharding self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards) - w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype)) - torch.nn.init.xavier_uniform_(w) + if self.zero_shards > 1: + assert self.zero_shards == dist.get_world_size( + ), "base weight sharding is only supported across world size" + w = torch.nn.Parameter(torch.empty(self.output_dim * self.sharded_weight_size, dtype=dtype), + requires_grad=False) + else: + w = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=dtype), requires_grad=False) + torch.nn.init.xavier_uniform_(w.reshape(self.sharded_weight_size, self.output_dim)) if self.quantization_config is not None: assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization" - self.base_weight = QuantizedParameter(w, quantization_config=quantization_config) + self.weight = QuantizedParameter(w, quantization_config=quantization_config) else: - self.base_weight = w + self.weight = w + + self.disabled = False + self._initialized = False + if not self.lora_config.delay_lora_init: + self.init_lora() + + def disable(self): + self.disabled = True + self.weight = torch.nn.Parameter(torch.empty((self.output_dim, self.input_dim), dtype=self.dtype), + requires_grad=False) + + def init_lora(self): + if self.disabled: + return + + if self.quantization_config is not None: + # ensure quant-param wasn't stripped, in some cases transformers will do this during model init + if not isinstance(self.weight, QuantizedParameter): + self.weight = QuantizedParameter(self.weight, quantization_config=self.quantization_config) + + self._initialized = True + self.weight.requires_grad = False - self.base_weight.requires_grad = False + # Mark base weight to prevent broadcast and ensure proper offload behavior + self.weight.ds_optim_param = True + + self.lora_scaling_factor = self.lora_config.lora_alpha / self.lora_config.lora_r - # Use RS lora for now. - self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r) # Keeping lora weights in bf16 precision for ease of training. - self.lora_weight_1 = nn.Linear(self.input_dim, - self.lora_config.lora_r, - bias=self.bias, - device=device, - dtype=dtype) - self.lora_weight_2 = nn.Linear(self.lora_config.lora_r, - self.output_dim, - bias=self.bias, - device=device, - dtype=dtype) + self.lora_weight_1 = self.linear_cls(self.input_dim, + self.lora_config.lora_r, + bias=self.bias, + device=self.device, + dtype=self.dtype) + self.lora_weight_2 = self.linear_cls(self.lora_config.lora_r, + self.output_dim, + bias=self.bias, + device=self.device, + dtype=self.dtype) + + # initialize "A" with kaiming uniform and "B" with zeros following this + # https://github.com/huggingface/peft/blob/62122b5add8d6892f70c82eaef2147a6ba33b90b/src/peft/tuners/lora/layer.py#L155 + nn.init.kaiming_uniform_(self.lora_weight_1.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_weight_2.weight) self.lora_weight_1.weight.requires_grad = True self.lora_weight_2.weight.requires_grad = True + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + if not any([target in prefix for target in self.lora_config.target_mods]): + # module does not match any target_mods, we must revert to normal nn.Linear via disable + self.disable() + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs) + + if self.zero_shards > 1: + if not dist.is_initialized(): + raise RuntimeError( + "attempting to use optimized linear base weight sharding but torch-distributed is not initialized, please init first." + ) + rank = dist.get_rank() + shape_local = self.output_dim * self.sharded_weight_size + base_weight_name = f"{prefix}weight" + incoming_param = state_dict[base_weight_name] + state_dict[base_weight_name] = incoming_param.flatten().narrow(0, rank * shape_local, shape_local) + + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + def full_weight(self): - # This assumes weights are evenly sharded across gpus. which might not be correct. - # in that case, we should flatten before all_gather. - local_weight = self.base_weight.dequantized() if isinstance(self.base_weight, - QuantizedParameter) else self.base_weight - tensor_list = [ - torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype) - for _ in range(self.zero_shards) - ] - dist.all_gather(tensor_list, local_weight) - weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1)) - return weight + base_weight = self.weight + if getattr(base_weight, 'ds_offload', False): + # move to gpu so we can dequant and all-gather + assert base_weight.device == torch.device('cpu'), \ + f"expected base weight on cpu but found {base_weight.device}" + base_weight.offload(revert=True) + local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight + base_weight.offload() + else: + local_weight = base_weight.dequantized() if isinstance(base_weight, QuantizedParameter) else base_weight + + tensor_out = torch.empty(self.output_dim * self.input_dim, + dtype=local_weight.dtype, + device=local_weight.device) + dist.all_gather_into_tensor(tensor_out, local_weight) + return tensor_out.reshape(self.output_dim, self.input_dim) def linear_without_F_linear(self, input, weight): output = torch.mm(input.reshape(-1, input.shape[-1]), weight) @@ -136,14 +204,18 @@ def linear_without_F_linear(self, input, weight): return output def forward(self, input_tensor): + if self.disabled: + return F.linear(input_tensor, self.weight) + assert self._initialized, "init_lora was never called, please initialize before proceeding" + # Gather the sharded base weight if self.zero_shards > 1: with torch.no_grad(): base_weight = self.full_weight() elif self.quantization_config: - base_weight = self.base_weight.dequantized() + base_weight = self.weight.dequantized() else: - base_weight = self.base_weight + base_weight = self.weight base_weight_output = F.linear(input_tensor, base_weight) lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor)) diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py index 8e4f23dfba89..70fabea845ba 100644 --- a/deepspeed/linear/quantization.py +++ b/deepspeed/linear/quantization.py @@ -75,6 +75,13 @@ def dequantized(self) -> torch.Tensor: q_mantisa_bits=self.quantization_config.mantissa_bits) return self.data + def offload(self, revert=False): + if getattr(self, 'ds_offload', False): + if revert: + self.data = self.to(get_accelerator().current_device_name()) + else: + self.data = self.to('cpu') + def __getstate__(self): state = self.__dict__ state["data"] = self.data @@ -104,7 +111,9 @@ def __copy__(self): return new_instance def cuda(self, device=None, non_blocking=False): - return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + device = "cuda" if device is None else device + self.quantizer.to(device, non_blocking=non_blocking) + return self.to(device, non_blocking=non_blocking) def to(self, *args, **kwargs): """ @@ -112,6 +121,7 @@ def to(self, *args, **kwargs): quantize it. """ tensor = super().to(*args, **kwargs) + self.quantizer.to(*args, **kwargs) self._ensure_quantized(tensor) return tensor diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 170954e0cf71..edd4ef57302c 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -91,6 +91,13 @@ def quantize(self, return out + def to(self, *args, **kwargs): + # Intermediate tensors may need to be moved to different devices + if hasattr(self, 'input_q'): + self.input_q = self.input_q.to(*args, **kwargs) + if hasattr(self, 'scale'): + self.scale = self.scale.to(*args, **kwargs) + def get_scales(self): return fp_quant_module.get_scales(self.scale, self.num_groups) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d40141132aaf..1c74c0c735a0 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -35,6 +35,8 @@ from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.bf16_optimizer import BF16_Optimizer +from deepspeed.linear.optimized_linear import LoRAOptimizedLinear + from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ @@ -326,6 +328,8 @@ def __init__(self, self.sparse_tensor_module_names.add(name + ".weight") logger.info("Will convert {} to sparse tensor during training".format(name)) + self._optimized_linear_offload_setup() + self.save_non_zero_checkpoint = False self.save_zero_checkpoint = False if not isinstance(self.optimizer, DeepSpeedZeRoOffload): @@ -363,6 +367,43 @@ def __init__(self, self._is_compiled = False + def _optimized_linear_offload_setup(self): + self.optimized_linear_base_weight_sharding = False + self.optimized_linear_lora_enabled = False + offload_ratio = None + for _, module in self.module.named_modules(): + if isinstance(module, LoRAOptimizedLinear): + self.optimized_linear_lora_enabled = True + offload_ratio = None + if offload_ratio is not None: + assert offload_ratio == module.lora_config.offload_ratio, \ + "all lora_config offload ratios should be the same across the model" + offload_ratio = module.lora_config.offload_ratio + if module.zero_shards > 1: + # set attr so checkpoint saving can handle BWS properly + self.optimized_linear_base_weight_sharding = True + + if offload_ratio is None: + # Nothing enabled, do nothing + return + + total_params = 0 + for _, p in self.module.named_parameters(): + if hasattr(p, 'ds_optim_param'): + total_params += p.numel() + + offload_limit = total_params * offload_ratio + logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params') + total_offloaded = 0 + for _, p in self.module.named_parameters(): + if hasattr(p, 'ds_optim_param'): + if total_offloaded < offload_limit: + total_offloaded += p.numel() + p.ds_offload = True + p.offload() + else: + p.ds_offload = False + def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() @@ -1054,9 +1095,12 @@ def _broadcast_model(self): def is_replicated(p): if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: return False + elif hasattr(p, 'ds_optim_param'): + # do not broadcast OptimizedLinear parameters, they are unique per base weight shard + return False return True - for p in self.module.parameters(): + for n, p in self.module.named_parameters(): # Broadcast the model for different parameters if is_moe_param(p): if torch.is_tensor(p) and is_replicated(p): diff --git a/tests/unit/linear/test_ctx.py b/tests/unit/linear/test_ctx.py new file mode 100644 index 000000000000..e03d13fd6ce2 --- /dev/null +++ b/tests/unit/linear/test_ctx.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import deepspeed +import pytest +from unit.common import DistributedTest + +import deepspeed.comm as dist +from deepspeed.linear import LoRAConfig, init_lora +from deepspeed.linear.optimized_linear import LoRAOptimizedLinear +from unit.simple_model import random_dataloader, SimpleModel + +try: + import transformers +except ImportError: + transformers = None + +if transformers is None: + pytest.skip("transformers is required for this test", allow_module_level=True) + + +def injection_assert(model): + # pick out random linear that should have been replaced and initialized + q_proj = model.model.layers[1].self_attn.q_proj + + assert isinstance(q_proj, LoRAOptimizedLinear), "injection did not happen" + assert q_proj._initialized, "lora was not initialized properly" + assert isinstance(q_proj.lora_weight_1, torch.nn.Linear) + assert isinstance(q_proj.lora_weight_2, torch.nn.Linear) + + +class TestEngine(DistributedTest): + world_size = 2 + + def test_model(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + hidden_dim = 64 + nlayers = 4 + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) + + init_lora(model) + + model_norms = [model.linears[i].weight.norm().item() for i in range(nlayers)] + + ds_config = { + "train_batch_size": 2, + "steps_per_print": 1, + "bf16": { + "enabled": True + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": 1 + } + } + model, *_ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters()) + + engine_norms = [model.module.linears[i].weight.norm().item() for i in range(nlayers)] + + # Ensure that sharded weights are not broadcast during engine init + assert engine_norms == model_norms, f"{dist.get_rank()=} base weight norms are not the same after engine init, {engine_norms=} != {model_norms=}" + + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + +class TestInitTransformers(DistributedTest): + world_size = 2 + + def test_pretrained_init(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = transformers.AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-Llama-3") + + injection_assert(model) + + def test_config_init(self): + lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=2) + quant_config = None + + config = transformers.AutoConfig.from_pretrained("llamafactory/tiny-random-Llama-3") + + with deepspeed.linear.Init(lora_config=lora_config, quant_config=quant_config): + model = transformers.AutoModelForCausalLM.from_config(config) + + injection_assert(model) From 0f2d485c273661a4cd9627bd4a0d2fe84fb66dc2 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Tue, 13 Aug 2024 21:10:17 -0400 Subject: [PATCH 04/13] Log operator warnings only in verbose mode (#5917) --- op_builder/evoformer_attn.py | 17 ++++++++++------ op_builder/fp_quantizer.py | 21 ++++++++++++-------- op_builder/inference_core_ops.py | 9 ++++++--- op_builder/inference_cutlass_builder.py | 9 ++++++--- op_builder/ragged_ops.py | 9 ++++++--- op_builder/ragged_utils.py | 9 ++++++--- op_builder/sparse_attn.py | 26 ++++++++++++++++--------- op_builder/spatial_inference.py | 6 ++++-- op_builder/transformer_inference.py | 9 ++++++--- 9 files changed, 75 insertions(+), 40 deletions(-) diff --git a/op_builder/evoformer_attn.py b/op_builder/evoformer_attn.py index 6e7721f94e01..af3aa7429775 100644 --- a/op_builder/evoformer_attn.py +++ b/op_builder/evoformer_attn.py @@ -41,18 +41,21 @@ def nvcc_args(self): args.append(f"-DGPU_ARCH={major}{minor}") return args - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile kernels") return False if self.cutlass_path is None: - self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH") + if verbose: + self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH") return False with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f: if '3.1.0' not in f.read(): - self.warning("Please use CUTLASS version >= 3.1.0") + if verbose: + self.warning("Please use CUTLASS version >= 3.1.0") return False cuda_okay = True if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda @@ -60,10 +63,12 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 7: - self.warning("Please use a GPU with compute capability >= 7.0") + if verbose: + self.warning("Please use a GPU with compute capability >= 7.0") cuda_okay = False if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("Please use CUDA 11+") + if verbose: + self.warning("Please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index c7d2e72b5408..40cf504c2c83 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -22,11 +22,12 @@ def __init__(self, name=None): def absolute_name(self): return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -35,17 +36,20 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 8: - self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Ampere and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False try: import triton except ImportError: - self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels") + if verbose: + self.warning(f"please install triton==2.3.0 or 2.3.1 if you want to use the FP Quantizer Kernels") return False # triton 2.3.0 and 2.3.1 are okay and the only versions released in 2.3.x before 3.x was released @@ -59,9 +63,10 @@ def is_compatible(self, verbose=True): triton_mismatch = major != "2" or minor != "3" if triton_mismatch: - self.warning( - f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels" - ) + if verbose: + self.warning( + f"FP Quantizer is using an untested triton version ({installed_triton}), only 2.3.0 and 2.3.1 are known to be compatible with these kernels" + ) return False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index d1957f39d9a8..45e8628e669f 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -23,7 +23,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -32,11 +33,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/inference_cutlass_builder.py b/op_builder/inference_cutlass_builder.py index 51f7931d9435..fda6e74bbf6a 100644 --- a/op_builder/inference_cutlass_builder.py +++ b/op_builder/inference_cutlass_builder.py @@ -22,7 +22,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -31,11 +32,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py index ec7cab91885f..a4e365786a2b 100644 --- a/op_builder/ragged_ops.py +++ b/op_builder/ragged_ops.py @@ -23,7 +23,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -32,11 +33,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/ragged_utils.py b/op_builder/ragged_utils.py index 89450e1fd30d..a855f072af8c 100755 --- a/op_builder/ragged_utils.py +++ b/op_builder/ragged_utils.py @@ -23,7 +23,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -32,11 +33,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py index 188d257ff4ef..2385adc8fe9c 100644 --- a/op_builder/sparse_attn.py +++ b/op_builder/sparse_attn.py @@ -27,45 +27,51 @@ def sources(self): def cxx_args(self): return ['-O2', '-fopenmp'] - def is_compatible(self, verbose=True): + def is_compatible(self, verbose=False): # Check to see if llvm and cmake are installed since they are dependencies #required_commands = ['llvm-config|llvm-config-9', 'cmake'] #command_status = list(map(self.command_exists, required_commands)) #deps_compatible = all(command_status) if self.is_rocm_pytorch(): - self.warning(f'{self.NAME} is not compatible with ROCM') + if verbose: + self.warning(f'{self.NAME} is not compatible with ROCM') return False try: import torch except ImportError: - self.warning(f"unable to import torch, please install it first") + if verbose: + self.warning(f"unable to import torch, please install it first") return False # torch-cpu will not have a cuda version if torch.version.cuda is None: cuda_compatible = False - self.warning(f"{self.NAME} cuda is not available from torch") + if verbose: + self.warning(f"{self.NAME} cuda is not available from torch") else: major, minor = torch.version.cuda.split('.')[:2] cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11) if not cuda_compatible: - self.warning(f"{self.NAME} requires CUDA version 10.1+") + if verbose: + self.warning(f"{self.NAME} requires CUDA version 10.1+") TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5) if not torch_compatible: - self.warning( - f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}') + if verbose: + self.warning( + f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}') try: import triton except ImportError: # auto-install of triton is broken on some systems, reverting to manual install for now # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710 - self.warning(f"please install triton==1.0.0 if you want to use sparse attention") + if verbose: + self.warning(f"please install triton==1.0.0 if you want to use sparse attention") return False if pkg_version: @@ -76,7 +82,9 @@ def is_compatible(self, verbose=True): triton_mismatch = installed_triton != "1.0.0" if triton_mismatch: - self.warning(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible") + if verbose: + self.warning( + f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible") return False return super().is_compatible(verbose) and torch_compatible and cuda_compatible diff --git a/op_builder/spatial_inference.py b/op_builder/spatial_inference.py index 59caf57f938d..8a6b36cce0b0 100644 --- a/op_builder/spatial_inference.py +++ b/op_builder/spatial_inference.py @@ -21,7 +21,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -31,7 +32,8 @@ def is_compatible(self, verbose=True): cuda_capability = torch.cuda.get_device_properties(0).major if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 5ee902289448..88b77499cc0e 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -21,7 +21,8 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + if verbose: + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True @@ -30,11 +31,13 @@ def is_compatible(self, verbose=True): torch_cuda_major = int(torch.version.cuda.split('.')[0]) cuda_capability = torch.cuda.get_device_properties(0).major if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + if verbose: + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + if verbose: + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay From 862aff37a58d777aa55107c1db2160c38f427fae Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Wed, 14 Aug 2024 10:50:40 +0800 Subject: [PATCH 05/13] Use `torch.nan_to_num` replace numpy wrapper one (#5877) Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/eigenvalue.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index df63854dd1ca..36300eb904dd 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -7,6 +7,7 @@ from deepspeed.utils import log_dist import numpy as np import logging +from deepspeed.utils.torch import required_torch_version class Eigenvalue(object): @@ -36,12 +37,15 @@ def __init__(self, ranks=[0]) # Replace all nan/pos-inf/neg-inf to zero - # TODO: Pytorch new version may add this function, replace this one by then. def nan_to_num(self, x): - device = x.device - x = x.cpu().numpy() - x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) - return torch.from_numpy(x).to(device) + if required_torch_version(min_version=1.8): + return torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + else: + # Fallback to numpy based implementation for backwards-compatibility with PyTorch 1.7 or older versions. + device = x.device + x = x.cpu().numpy() + x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + return torch.from_numpy(x).to(device) def normalize(self, v): norm_squared = self.inner_product(v, v) From 24d1d86ba294c7794dce5848d8caac55438fc5ad Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Wed, 14 Aug 2024 19:19:21 +0800 Subject: [PATCH 06/13] [Zero2] Reduce the unnecessary all-reduce when tensor size is 0. (#5868) When running for Zero2, the reduce_bucket_size we set is not large enough, the self.elements_in_ipg_bucket will be 0, then in function average_tensor the input is the tensor with size=0 https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L1372 use reduce_scatter can be WA https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L1066 if user uses the reduce_scatter=false, in function gradient_reduction_w_predivide will meet the unnecessary all-reduce with tensor size is 0. https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py#L974 This pr is to add the judgement to reduce this unnecessary all-reduce. Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/stage_1_and_2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 461ad719536a..57e80911d645 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -972,6 +972,8 @@ def print_rank_0(self, message): logger.info(message) def gradient_reduction_w_predivide(self, tensor): + if tensor.size().numel() == 0: + return tensor dp_world_size = dist.get_world_size(group=self.dp_process_group) From 051c993310f0c521ee0758d717168edcdf60c77e Mon Sep 17 00:00:00 2001 From: Raza Sikander <54884406+raza-sikander@users.noreply.github.com> Date: Wed, 14 Aug 2024 20:36:38 +0530 Subject: [PATCH 07/13] Update container version for Gaudi2 CI (#5937) Update version to 1.17.0 from 1.16.2 Co-authored-by: Shaik Raza Sikander --- .github/workflows/hpu-gaudi2.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index ac19638e67de..0272829e8286 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.16.2/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + image: vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice From e3177de666f9aca1a9e16c57ffe38f22ebf54d38 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 14 Aug 2024 07:24:15 -0400 Subject: [PATCH 08/13] Fix missing ds_id bug (#5824) Fix #5495 - Fix missing ds_id bug by copying solution from #5193 (credit to @getinglxf) Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/stage3.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 3ac6987e9c22..9b7645261eae 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -821,10 +821,14 @@ def _create_fp32_partitions(self): for i, tensor in enumerate(self.fp16_partitioned_groups_flat): num_elements = self.fp16_partitioned_groups_flat_numel[i] + ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0]) + ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1]) + ds_id = ds_id_begin + '_' + ds_id_end # a partition of the fp32 master weights that will be updated by this process if self._swappable_optimizer_subgroup(i): self.fp32_partitioned_groups_flat.append(torch.Tensor()) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id nvme_memory_usage += (fp32_element_size * num_elements) num_swappable_partitions += 1 @@ -861,11 +865,9 @@ def _create_fp32_partitions(self): else: self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( self.device).clone().float().detach()) + self.fp32_partitioned_groups_flat[i].ds_id = ds_id self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it - ds_id_begin = str(self.fp16_partitioned_groups_flat_id[i][0]) - ds_id_end = str(self.fp16_partitioned_groups_flat_id[i][-1]) - self.fp32_partitioned_groups_flat[i].ds_id = ds_id_begin + '_' + ds_id_end if len(swappable_fp32_tensors) > 0: self.optimizer_swapper.initialize_parameters(parameters=swappable_fp32_tensors, From f994fb2c4e418e56c26c9f23372d0e334d0c2ccb Mon Sep 17 00:00:00 2001 From: Xi Yang Date: Wed, 14 Aug 2024 11:08:41 -0400 Subject: [PATCH 09/13] Update LR scheduler configuration (#5846) This PR is based on https://github.com/microsoft/DeepSpeed/issues/5726. The current lr scheduler initialization always prioritize config over manual defined scheduler in the code. However, the optimizer initialization implementation prioritize manual defined optimizer over config. This PR aims to make initialization behavior for both optimizer and lr scheduler consistent where if lr scheduler is defined in the code, then it will overwrite config. --------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/engine.py | 22 ++-- tests/unit/runtime/test_ds_initialize.py | 129 +++++++++++++++++++++++ 2 files changed, 140 insertions(+), 11 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 1c74c0c735a0..d2839a8f5d7c 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -306,7 +306,7 @@ def __init__(self, if has_optimizer: self._configure_optimizer(optimizer, model_parameters) - self._configure_lr_scheduler(lr_scheduler) + self._configure_lr_scheduler() self._report_progress(0) elif self.zero_optimization(): # no optim selected but zero is enabled @@ -943,19 +943,19 @@ def _optimizer_has_ckpt_event_prologue(self): def _optimizer_has_ckpt_event_epilogue(self): return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue') - def _configure_lr_scheduler(self, client_lr_scheduler): - # First check for scheduler in json configuration - lr_scheduler = self._scheduler_from_config(self.optimizer) - if lr_scheduler: - log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0]) - self.lr_scheduler = lr_scheduler - else: - if isinstance(client_lr_scheduler, Callable): + def _configure_lr_scheduler(self): + if self.client_lr_scheduler: + if isinstance(self.client_lr_scheduler, Callable): log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0]) - self.lr_scheduler = client_lr_scheduler(self.basic_optimizer) + self.lr_scheduler = self.client_lr_scheduler(self.basic_optimizer) else: log_dist('DeepSpeed using client LR scheduler', ranks=[0]) - self.lr_scheduler = client_lr_scheduler + self.lr_scheduler = self.client_lr_scheduler + else: + # load lr scheduler from json configuration if lr scheduler is not defined and passed in + lr_scheduler = self._scheduler_from_config(self.optimizer) + log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0]) + self.lr_scheduler = lr_scheduler log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py index 9ff99f169f7a..a30f81cedde9 100644 --- a/tests/unit/runtime/test_ds_initialize.py +++ b/tests/unit/runtime/test_ds_initialize.py @@ -305,3 +305,132 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler: assert ds_lr_scheduler == client_scheduler else: assert isinstance(ds_lr_scheduler, LambdaLR) + + +@pytest.mark.parametrize("scheduler_type", [None, _LRScheduler, Callable]) +class TestClientLrSchedulerInit(DistributedTest): + world_size = 1 + + def test_same_lrscheler_and_callable(self, scheduler_type): + """ + Expect behavior + + if lr scheduler is defined in code and passed into initialize as arg, + it will be used even this is a lr scheduler has been defined in config. + + Initialize lr scheduler from config when no lr scheduler is defined in code. + """ + + def _my_lambda(epoch): + return epoch // 10 + + def _lr_scheduler_callable(optimizer) -> _LRScheduler: + return LambdaLR(optimizer, _my_lambda) + + config_dict = {'train_batch_size': 1} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + if scheduler_type is None: + config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}} + client_scheduler = None + elif scheduler_type == _LRScheduler: + client_scheduler = LambdaLR(client_optimizer, _my_lambda) + else: + client_scheduler = _lr_scheduler_callable + + _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer, + lr_scheduler=client_scheduler) + if scheduler_type is None: + # in this case, we initialize from config + assert not isinstance(ds_lr_scheduler, LambdaLR) + assert isinstance(ds_lr_scheduler, WarmupLR) + else: + # in this case, we initialize from passed-in scheduler + assert isinstance(ds_lr_scheduler, LambdaLR) + assert not isinstance(ds_lr_scheduler, WarmupLR) + + def test_diff_lrscheler_and_callable(self, scheduler_type): + """ + In this test, + the LambdaLR will be used for lrscheduler type + and the StepLR will be used for callable type + """ + + from torch.optim.lr_scheduler import StepLR + + def _my_lambda(epoch): + return epoch // 10 + + def _lr_scheduler_callable(optimizer) -> _LRScheduler: + return StepLR(optimizer, step_size=30) + + config_dict = {'train_batch_size': 1} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + if scheduler_type is None: + config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}} + client_scheduler = None + elif scheduler_type == _LRScheduler: + client_scheduler = LambdaLR(client_optimizer, _my_lambda) + else: + client_scheduler = _lr_scheduler_callable + + _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer, + lr_scheduler=client_scheduler) + if scheduler_type is None: + assert isinstance(ds_lr_scheduler, WarmupLR) + elif scheduler_type == _LRScheduler: + assert isinstance(ds_lr_scheduler, LambdaLR) + else: + # callable + assert isinstance(ds_lr_scheduler, StepLR) + + def test_diff_lrscheler_and_callable_onecyclelr_steplr(self, scheduler_type): + + from deepspeed.runtime.lr_schedules import OneCycle, ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR + from torch.optim.lr_scheduler import OneCycleLR, StepLR + + def _lr_scheduler_callable(optimizer) -> _LRScheduler: + return OneCycleLR(optimizer, max_lr=0.01, total_steps=200) + + config_dict = {'train_batch_size': 1} + + hidden_dim = 10 + model = SimpleModel(hidden_dim) + + client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + if scheduler_type is None: + config_dict['scheduler'] = {'type': ONE_CYCLE, 'params': {CYCLE_MIN_LR: 0, CYCLE_MAX_LR: 0.1}} + client_scheduler = None + elif scheduler_type == _LRScheduler: + client_scheduler = StepLR(client_optimizer, step_size=30) + else: + client_scheduler = _lr_scheduler_callable + + _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=list(model.parameters()), + optimizer=client_optimizer, + lr_scheduler=client_scheduler) + if scheduler_type is None: + assert isinstance(ds_lr_scheduler, OneCycle) + elif scheduler_type == _LRScheduler: + assert isinstance(ds_lr_scheduler, StepLR) + else: + # callable + assert isinstance(ds_lr_scheduler, OneCycleLR) From a8d1b44d4ba50c8635a000d096f3c9a2df21be20 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Wed, 14 Aug 2024 23:21:55 +0300 Subject: [PATCH 10/13] HPUAccelerator: remove support in set_visible_devices_envs (#5929) The way deepspeed sets it is not correct with all HPU instances and may lead to incorrect behavior. Co-authored-by: Olatunji Ruwase --- accelerator/hpu_accelerator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 5c2e92c9ef69..485b205f3e67 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -298,7 +298,11 @@ def export_envs(self): return [] def visible_devices_envs(self): - return ['HABANA_VISIBLE_MODULES'] + # Current way deepspeed set this env var is not applicable with all HPU instances + # User has to follow instructions in: + # https://docs.habana.ai/en/latest/PyTorch/Reference/PT_Multiple_Tenants_on_HPU/Multiple_Workloads_Single_Docker.html + # keeping CUDA_VISIBLE_DEVICES + return ['CUDA_VISIBLE_DEVICES'] #['HABANA_VISIBLE_MODULES'] def set_visible_devices_envs(self, current_env, local_accelerator_ids): for env in self.visible_devices_envs(): From 19b01e1d6089103d1ae397f2a60a178a18f63ab5 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 15 Aug 2024 06:43:37 +0800 Subject: [PATCH 11/13] Add accelerator setup guides (#5827) This document provide a places to hold accelerator setup guides. It is intend to be a single place to lookup installation guides for different accelerators. Currently CPU and XPU setup guides are added to this document and could be extended to other accelerators. --------- Co-authored-by: Liangliang Ma <1906710196@qq.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- .../accelerator-abstraction-interface.md | 67 +-------- docs/_tutorials/accelerator-setup-guide.md | 134 ++++++++++++++++++ 2 files changed, 137 insertions(+), 64 deletions(-) create mode 100644 docs/_tutorials/accelerator-setup-guide.md diff --git a/docs/_tutorials/accelerator-abstraction-interface.md b/docs/_tutorials/accelerator-abstraction-interface.md index 88a43236ce9d..d7c153638c0d 100644 --- a/docs/_tutorials/accelerator-abstraction-interface.md +++ b/docs/_tutorials/accelerator-abstraction-interface.md @@ -12,7 +12,6 @@ tags: getting-started - [Tensor operations](#tensor-operations) - [Communication backend](#communication-backend) - [Run DeepSpeed model on different accelerators](#run-deepspeed-model-on-different-accelerators) -- [Run DeepSpeed model on CPU](#run-deepspeed-model-on-cpu) - [Implement new accelerator extension](#implement-new-accelerator-extension) # Introduction @@ -79,69 +78,9 @@ torch.distributed.init_process_group(get_accelerator().communication_backend_nam ``` # Run DeepSpeed model on different accelerators -Once a model is ported with DeepSpeed Accelerator Abstraction Interface, we can run this model on different accelerators using an extension to DeepSpeed. DeepSpeed checks whether a certain extension is installed in the environment to decide whether to use the Accelerator backend in that extension. For example, if we wish to run a model on Intel GPU, we can install _Intel Extension for DeepSpeed_ following the instructions in the following [link](https://github.com/intel/intel-extension-for-deepspeed/) - -After the extension is installed, install DeepSpeed and run the model. The model will be running on top of DeepSpeed. Because DeepSpeed installation is also accelerator related, it is recommended to install DeepSpeed accelerator extension before installing DeepSpeed. - -`CUDA_Accelerator` is the default accelerator in DeepSpeed. If no other DeepSpeed accelerator extension is installed, `CUDA_Accelerator` will be used. - -When running a model on different accelerators in a cloud environment, the recommended practice is to provision an environment for each accelerator in a different env with tools such as _anaconda/miniconda/virtualenv_. When running models on different Accelerator, load the env accordingly. - -Note that different accelerator may have different 'flavor' of float16 or bfloat16. So it is recommended to make the model configurable for both float16 and bfloat16, in that way model code does not need to be changed when running on different accelerators. - -# Run DeepSpeed model on CPU -DeepSpeed support using CPU as accelerator. DeepSpeed model using DeepSpeed Accelerator Abstraction Interface could run on CPU without change to model code. DeepSpeed decide whether _Intel Extension for PyTorch_ is installed in the environment. If this packaged is installed, DeepSpeed will use CPU as accelerator. Otherwise CUDA device will be used as accelerator. - -To run DeepSpeed model on CPU, use the following steps to prepare environment: - -``` -python -m pip install intel_extension_for_pytorch -python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu -git clone https://github.com/oneapi-src/oneCCL -cd oneCCL -mkdir build -cd build -cmake .. -make -make install -``` - -Before run CPU workload, we need to source oneCCL environment variables -``` -source /build/_install/env/setvars.sh -``` - -After environment is prepared, we can launch DeepSpeed inference with the following command -``` -deepspeed --bind_cores_to_rank -``` - -This command would launch number of workers equal to number of CPU sockets on the system. Currently DeepSpeed support running inference model with AutoTP on top of CPU. The argument `--bind_cores_to_rank` distribute CPU cores on the system evenly among workers, to allow each worker running on a dedicated set of CPU cores. - -On CPU system, there might be daemon process that periodically activate which would increase variance of each worker. One practice is leave a couple of cores for daemon process using `--bind-core-list` argument: - -``` -deepspeed --bind_cores_to_rank --bind_core_list 0-51,56-107 -``` - -The command above leave 4 cores on each socket to daemon process (assume two sockets, each socket has 56 cores). - -We can also set an arbitrary number of workers. Unlike GPU, CPU cores on host can be further divided into subgroups. When this number is not set, DeepSpeed would detect number of NUMA nodes on the system and launch one worker for each NUMA node. - -``` -deepspeed --num_accelerators 4 --bind_cores_to_rank -``` - -Launching DeepSpeed model on multiple CPU nodes is similar to other accelerators. We need to specify `impi` as launcher and specify `--bind_cores_to_rank` for better core binding. Also specify `slots` number according to number of CPU sockets in host file. - -``` -# hostfile content should follow the format -# worker-1-hostname slots=<#sockets> -# worker-2-hostname slots=<#sockets> -# ... - -deepspeed --hostfile= --bind_cores_to_rank --launcher impi --master_addr -``` +[Accelerator Setup Guide](accelerator-setup-guide.md) provides a guide on how to setup different accelerators for DeepSpeed. It also comes with simple example how to run deepspeed for different accelerators. The following guides are provided: +1. Run DeepSpeed model on CPU +2. Run DeepSpeed model on XPU # Implement new accelerator extension It is possible to implement a new DeepSpeed accelerator extension to support new accelerator in DeepSpeed. An example to follow is _[Intel Extension For DeepSpeed](https://github.com/intel/intel-extension-for-deepspeed/)_. An accelerator extension contains the following components: diff --git a/docs/_tutorials/accelerator-setup-guide.md b/docs/_tutorials/accelerator-setup-guide.md new file mode 100644 index 000000000000..cf2d01d2b25c --- /dev/null +++ b/docs/_tutorials/accelerator-setup-guide.md @@ -0,0 +1,134 @@ +--- +title: DeepSpeed Accelerator Setup Guides +tags: getting-started +--- + +# Contents +- [Contents](#contents) +- [Introduction](#introduction) +- [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu) +- [Intel XPU](#intel-xpu) + +# Introduction +DeepSpeed supports different accelerators from different companies. Setup steps to run DeepSpeed on certain accelerators might be different. This guide allows user to lookup setup instructions for the accelerator family and hardware they are using. + +# Intel Architecture (IA) CPU +DeepSpeed supports CPU with Intel Architecture instruction set. It is recommended to have the CPU support at least AVX2 instruction set and recommend AMX instruction set. + +DeepSpeed has been verified on the following CPU processors: +* 4th Gen Intel® Xeon® Scalarable Processors +* 5th Gen Intel® Xeon® Scalarable Processors +* 6th Gen Intel® Xeon® Scalarable Processors + +## Installation steps for Intel Architecture CPU +To install DeepSpeed on Intel Architecture CPU, use the following steps: +1. Install gcc compiler +DeepSpeed requires gcc-9 or above to build kernels on Intel Architecture CPU, install gcc-9 or above. + +2. Install numactl +DeepSpeed use `numactl` for fine grain CPU core allocation for load-balancing, install numactl on your system. +For example, on Ubuntu system, use the following command: +`sudo apt-get install numactl` + +3. Install PyTorch +`pip install torch` + +4. Install DeepSpeed +`pip install deepspeed` + +## How to launch DeepSpeed on Intel Architecture CPU +DeepSpeed can launch on Intel Architecture CPU with default deepspeed command. However, for compute intensive workloads, Intel Architecture CPU works best when each worker process runs on different set of physical CPU cores, so worker process does not compete CPU cores with each other. To bind cores to each worker (rank), use the following command line switch for better performance. +``` +deepspeed --bind_cores_to_rank +``` +This switch would automatically detect the number of CPU NUMA node on the host, launch the same number of workers, and bind each worker to cores/memory of a different NUMA node. This improves performance by ensuring workers do not interfere with each other, and that all memory allocation is from local memory. + +If a user wishes to have more control on the number of workers and specific cores that can be used by the workload, user can use the following command line switches. +``` +deepspeed --num_accelerators --bind_cores_to_rank --bind_core_list +``` +For example: +``` +deepspeed --num_accelerators 4 --bind_cores_to_rank --bind_core_list <0-27,32-59> inference.py +``` +This would start 4 workers for the workload. The core list range will be divided evenly between 4 workers, with worker 0 take 0-13, worker 1, take 14-27, worker 2 take 32-45, and worker 3 take 46-59. Core 28-31,60-63 are left out because there might be some background process running on the system, leaving some idle cores will reduce performance jitting and straggler effect. + +Launching DeepSpeed model on multiple CPU nodes is similar to other accelerators. We need to specify `impi` as launcher and specify `--bind_cores_to_rank` for better core binding. Also specify `slots` number according to number of CPU sockets in host file. + +``` +# hostfile content should follow the format +# worker-1-hostname slots=<#sockets> +# worker-2-hostname slots=<#sockets> +# ... + +deepspeed --hostfile= --bind_cores_to_rank --launcher impi --master_addr +``` + +## Install with Intel Extension for PyTorch and oneCCL +Although not mandatory, Intel Extension for PyTorch and Intel oneCCL provide better optimizations for LLM models. Intel oneCCL also provide optimization when running LLM model on multi-node. To use DeepSpeed with Intel Extension for PyTorch and oneCCL, use the following steps: +1. Install Intel Extension for PyTorch. This is suggested if you want to get better LLM inference performance on CPU. +`pip install intel-extension-for-pytorch` + +The following steps are to install oneCCL binding for PyTorch. This is suggested if you are running DeepSpeed on multiple CPU node, for better communication performance. On single node with multiple CPU socket, these steps are not needed. + +2. Install Intel oneCCL binding for PyTorch +`python -m pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable-cpu` + +3. Install Intel oneCCL, this will be used to build direct oneCCL kernels (CCLBackend kernels) +``` +pip install oneccl-devel +pip install impi-devel +``` +Then set the environment variables for Intel oneCCL (assuming using conda environment). +``` +export CPATH=${CONDA_PREFIX}/include:$CPATH +export CCL_ROOT=${CONDA_PREFIX} +export I_MPI_ROOT=${CONDA_PREFIX} +export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/ccl/cpu:${CONDA_PREFIX}/lib/libfabric:${CONDA_PREFIX}/lib +``` + +## Optimize LLM inference with Intel Extension for PyTorch +Intel Extension for PyTorch compatible with DeepSpeed AutoTP tensor parallel inference. It allows CPU inference to benefit from both DeepSpeed Automatic Tensor Parallelism, and LLM optimizations of Intel Extension for PyTorch. To use Intel Extension for PyTorch, after calling deepspeed.init_inference, call +``` +ipex_model = ipex.llm.optimize(deepspeed_model) +``` +to get model optimzied by Intel Extension for PyTorch. + +## More example for using DeepSpeed with Intel Extension for PyTorch on Intel Architecture CPU +Refer to https://github.com/intel/intel-extension-for-pytorch/tree/main/examples/cpu/inference/python/llm for more extensive guide. + +# Intel XPU +DeepSpeed XPU accelerator supports Intel® Data Center GPU Max Series. + +DeepSpeed has been verified on the following GPU products: +* Intel® Data Center GPU Max 1100 +* Intel® Data Center GPU Max 1550 + +## Installation steps for Intel XPU +To install DeepSpeed on Intel XPU, use the following steps: +1. Install oneAPI base toolkit \ +The Intel® oneAPI Base Toolkit (Base Kit) is a core set of tools and libraries, including an DPC++/C++ Compiler for building Deepspeed XPU kernels like fusedAdam and CPUAdam, high performance computation libraries demanded by IPEX, etc. +For easy download, usage and more details, check [Intel oneAPI base-toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html). +2. Install PyTorch, Intel extension for pytorch, Intel oneCCL Bindings for PyTorch. These packages are required in `xpu_accelerator` for torch functionality and performance, also communication backend on Intel platform. The recommended installation reference: +https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu. + +3. Install DeepSpeed \ +`pip install deepspeed` + +## How to use DeepSpeed on Intel XPU +DeepSpeed can be launched on Intel XPU with deepspeed launch command. Before that, user needs activate the oneAPI environment by: \ +`source /setvars.sh` + +To validate the XPU availability and if the XPU accelerator is correctly chosen, here is an example: +``` +$ python +>>> import torch; print('torch:', torch.__version__) +torch: 2.3.0 +>>> import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available()) +XPU available: True +>>> from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name) +accelerator: xpu +``` + +## More example for using DeepSpeed on Intel XPU +Refer to https://github.com/intel/intel-extension-for-pytorch/tree/release/xpu/2.1.40/examples/gpu/inference/python/llm for more extensive guide. From 6eed634eda502300b702f7a80c23f24aea08ed29 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Thu, 15 Aug 2024 02:38:45 +0300 Subject: [PATCH 12/13] Z3: optimizations for grad norm calculation and gradient clipping (#5504) This PR add the below functionality: 1. complete_grad_norm_calculation_for_cpu_offload: move total_norm to CPU, as expected device in such case is CPU.. 2. repalce get_global_norm() with torch.linalg.norm for better performance. 3. unscale_and_clip_grads: replace clipping based on if statement to use torch.clamp for better performance. change (3) is taken from https://github.com/microsoft/DeepSpeed/pull/5547 (which was closed) --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com> Co-authored-by: Liran Bachar --- deepspeed/runtime/zero/stage3.py | 10 +++++----- tests/unit/runtime/zero/test_zero_offloadpp.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 9b7645261eae..b0a3ab778f2a 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -15,7 +15,7 @@ from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce -from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -1413,7 +1413,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm - return total_norm + return total_norm.cpu() @instrument_w_nvtx def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: @@ -2028,7 +2028,7 @@ def step(self, closure=None): return norm_groups = self._get_norm_groups() - scaled_global_grad_norm = get_global_norm(norm_list=norm_groups) + scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups)) # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.loss_scale @@ -2112,8 +2112,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm): if self.clip_grad > 0.: # norm is in fact norm*scale clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale + clip = torch.clamp(clip, min=1.0) + combined_scale = clip * self.loss_scale self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale) diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py index 5bfec399e19f..8ae99e2237e2 100644 --- a/tests/unit/runtime/zero/test_zero_offloadpp.py +++ b/tests/unit/runtime/zero/test_zero_offloadpp.py @@ -43,6 +43,7 @@ def test(self, h_dim: int, n_layers: int) -> None: config_dict = { "train_batch_size": 256, "steps_per_print": 1, + "gradient_clipping": 1.0, "optimizer": { "type": "Adam", "params": { From 4ba49ddad817fc5241867b08677ec91b2d3070cf Mon Sep 17 00:00:00 2001 From: Liangliang Ma Date: Thu, 15 Aug 2024 07:54:53 +0800 Subject: [PATCH 13/13] Update xpu-max1100.yml with new config and add some tests (#5668) This PR: 1.Change the container 2.Update the software version (align with docker compiler) 3. Add some tests --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- .github/workflows/xpu-max1100.yml | 36 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/xpu-max1100.yml b/.github/workflows/xpu-max1100.yml index c5a23fe3f53f..1042db100a21 100644 --- a/.github/workflows/xpu-max1100.yml +++ b/.github/workflows/xpu-max1100.yml @@ -36,38 +36,36 @@ jobs: unit-tests: runs-on: [self-hosted, intel, xpu] container: - image: intel/intel-extension-for-pytorch:2.1.30-xpu + image: intel/oneapi-basekit:2024.1.1-devel-ubuntu22.04 ports: - 80 options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL steps: - uses: actions/checkout@v4 - - name: Check container state - shell: bash - run: | - ldd --version - python -c "import torch; print('torch:', torch.__version__, torch)" - python -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())" - - - name: Install deepspeed + - name: Install prerequisite run: | - pip install py-cpuinfo + apt-get update + apt-get install clinfo libaio-dev python3-pip -y + pip install torch==2.1.0.post2 -f https://developer.intel.com/ipex-whl-stable-xpu + pip install intel-extension-for-pytorch==2.1.30+xpu -f https://developer.intel.com/ipex-whl-stable-xpu + pip install intel-extension-for-pytorch-deepspeed==2.1.30 -f https://developer.intel.com/ipex-whl-stable-xpu + pip install oneccl_bind_pt==2.1.300+xpu -f https://developer.intel.com/ipex-whl-stable-xpu + pip install torchvision==0.16.0.post2 -f https://developer.intel.com/ipex-whl-stable-xpu + pip install py-cpuinfo numpy==1.26 pip install .[dev,autotuning] - ds_report - python -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)" - - name: Python environment + - name: Check container state run: | + ldd --version + ds_report + python3 -c "import torch; print('torch:', torch.__version__, torch)" + python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())" + python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)" pip list - name: Unit tests run: | - pip install pytest pytest-timeout tabulate tensorboard wandb - export ONEAPI_ROOT=/opt/intel/oneapi/redist - export FI_PROVIDER_PATH=$ONEAPI_ROOT/opt/mpi/libfabric/lib/prov - export LD_LIBRARY_PATH=$ONEAPI_ROOT/opt/mpi/libfabric/lib:$LD_LIBRARY_PATH - export LD_LIBRARY_PATH=$ONEAPI_ROOT/lib:$LD_LIBRARY_PATH cd tests/unit pytest --verbose accelerator/* pytest --verbose autotuning/* @@ -75,8 +73,10 @@ jobs: pytest --verbose checkpoint/test_moe_checkpoint.py pytest --verbose checkpoint/test_shared_weights.py pytest --verbose launcher/test_ds_arguments.py launcher/test_run.py + pytest --verbose model_parallelism/* pytest --verbose moe/test_moe_tp.py pytest --verbose monitor/* + pytest --verbose utils/* pytest --verbose runtime/test_ds_config_model.py pytest --verbose runtime/pipe/test_pipe_schedule.py pytest --verbose runtime/zero/test_zero_config.py