
Fix training #774

Merged: 37 commits, merged on Feb 26, 2025

Changes from 36 commits

Commits (37)
4284963
WIP fix training loss
michaelbenayoun Jan 30, 2025
38ef8f1
WIP fix training loss
michaelbenayoun Jan 31, 2025
0b7f6d5
Small fix on prepare_dataloder
michaelbenayoun Jan 31, 2025
39ab969
Fix grad norm
michaelbenayoun Jan 31, 2025
a7f7695
Fix training loop
michaelbenayoun Feb 6, 2025
b67f265
Temporary change of this file to debug thing
michaelbenayoun Feb 11, 2025
fe231bf
Fix mixed precision
michaelbenayoun Feb 11, 2025
26baa0f
Restore script for tutorial
michaelbenayoun Feb 11, 2025
a1412e2
Fix style
michaelbenayoun Feb 11, 2025
39e54e1
Fix style
michaelbenayoun Feb 11, 2025
0639f91
Fix dtype casting issue with TP
michaelbenayoun Feb 12, 2025
e1b2591
[WIP]
michaelbenayoun Feb 19, 2025
3da78b4
[WIP]
michaelbenayoun Feb 20, 2025
11375d7
[WIP]
michaelbenayoun Feb 20, 2025
48f0116
Remove LinearWithAsyncCommunicationFixed
michaelbenayoun Feb 20, 2025
d97c835
Add docstring for parallel_cross_entropy
michaelbenayoun Feb 20, 2025
56fd4c3
Cleanup
michaelbenayoun Feb 20, 2025
1b4080b
Remove empty spaces
michaelbenayoun Feb 21, 2025
fc34abc
Merge branch 'main' into align_for_training
michaelbenayoun Feb 21, 2025
e3623a0
Remove SDK 2.20 specifics
michaelbenayoun Feb 21, 2025
3aaa023
Styling
michaelbenayoun Feb 21, 2025
1807f71
Fixes
michaelbenayoun Feb 21, 2025
2c2a101
Remove unused comment
michaelbenayoun Feb 21, 2025
a59744d
Fix
michaelbenayoun Feb 21, 2025
ccb6960
Fix
michaelbenayoun Feb 24, 2025
3a5a168
[WIP] fix GQA QKV
michaelbenayoun Jan 10, 2025
8704584
[WIP] GQA checkpointing works, but output_proj does not work
michaelbenayoun Jan 15, 2025
d29ed89
Fix output_proj
michaelbenayoun Jan 15, 2025
08d58f8
Fix
michaelbenayoun Feb 24, 2025
e885d08
Fix
michaelbenayoun Feb 24, 2025
3ee8bda
Styling
michaelbenayoun Feb 24, 2025
031acce
Fix loss zeroing
michaelbenayoun Feb 26, 2025
1d83221
Fix test
michaelbenayoun Feb 26, 2025
701360b
Styling
michaelbenayoun Feb 26, 2025
412e16f
Fix teardown
michaelbenayoun Feb 26, 2025
b920479
Styling
michaelbenayoun Feb 26, 2025
2be6cbc
Apply suggestions
michaelbenayoun Feb 26, 2025
54 changes: 18 additions & 36 deletions optimum/neuron/accelerate/accelerator.py
@@ -62,7 +62,6 @@
)
from .utils.misc import (
apply_activation_checkpointing,
create_patched_finfo,
create_patched_save_pretrained,
)
from .utils.operations import _xla_gather
@@ -203,7 +202,9 @@ def _prepare_data_loader_for_distributed(
distributed_dataloader._is_accelerate_prepared = True
return distributed_dataloader

def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optional[bool] = None):
def prepare_data_loader(
self, data_loader: DataLoader, device_placement: Optional[bool] = None, use_mp_device_loader: bool = False
):
force_drop_last = False
if self.state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
from neuronx_distributed import parallel_layers
@@ -224,11 +225,9 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona
data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
)
# No need to wrap the dataloader if we are using pipeline parallelism.
if self.state.mp_plugin.pipeline_parallel_size == 1:
if use_mp_device_loader and self.state.mp_plugin.pipeline_parallel_size == 1:
data_loader = MpDeviceLoader(data_loader, self.device)
return data_loader
# TODO: fix that.
# return super().prepare_data_loader(data_loader, device_placement=device_placement)

def _prepare_optimizer_for_mp(self, optimizer: torch.optim.Optimizer, device_placement=None):
cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
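
For context, a minimal sketch of what the new `use_mp_device_loader` flag controls: the dataloader is only wrapped in `MpDeviceLoader` when the caller opts in and pipeline parallelism is not used. The helper name and the toy dataset below are illustrative, not part of the PR.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def maybe_wrap_for_xla(data_loader, device, use_mp_device_loader=False, pipeline_parallel_size=1):
    # Mirrors the new default: keep the plain DataLoader unless wrapping is requested
    # explicitly and no pipeline parallelism is involved.
    if use_mp_device_loader and pipeline_parallel_size == 1:
        from torch_xla.distributed.parallel_loader import MpDeviceLoader

        return MpDeviceLoader(data_loader, device)
    return data_loader


loader = DataLoader(TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1)), batch_size=2)
wrapped = maybe_wrap_for_xla(loader, device=None)  # default path: returned unchanged
```
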
@@ -329,19 +328,6 @@ def patch_model_for_neuron(
# Working on a copy for safety.
patching_specs = list(patching_specs)

mixed_precision_is_bf16 = self.state.mixed_precision == "bf16"
patched_finfo = create_patched_finfo(
xla_downcast_bf16=mixed_precision_is_bf16 and self.state.downcast_bfloat,
use_amp=mixed_precision_is_bf16 and self.state.autocast_backend is AutocastBackend.AMP,
xla_use_bf16=mixed_precision_is_bf16 and not self.state.downcast_bfloat,
)
patching_specs.append(
(
"forward",
DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
),
)

if isinstance(model, PreTrainedModel):
patching_specs.append(
(
@@ -459,6 +445,9 @@ def prepare_model(

model = self.patch_model_for_neuron(model)

if self.state.mixed_precision == "bf16":
model.to(torch.bfloat16)

# We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
# need.
model.config.use_cache = False
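
In isolation, the new bf16 handling amounts to casting the model's parameters up front instead of patching `torch.finfo`. A plain-PyTorch sketch, not Neuron-specific:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
# With mixed_precision == "bf16", the model is now cast directly rather than
# relying on a patched torch.finfo or the XLA_USE_BF16 environment variable.
model = model.to(torch.bfloat16)
print(next(model.parameters()).dtype)  # torch.bfloat16
```
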
@@ -529,24 +518,17 @@ def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[Autoc
yield
autocast_context.__exit__(*sys.exc_info())

@requires_neuronx_distributed
def _prepare_clip_grad_norm(self, parameters, max_norm, norm_type: int = 2):
from neuronx_distributed.pipeline import NxDPPModel

self.unscale_gradients()
parameters = list(parameters)
for model in self._models:
model_parameters = model.local_parameters() if isinstance(model, NxDPPModel) else model.parameters()
if parameters == list(model_parameters) or self.zero_1:
for opt in self._optimizers:
# Under this setting, the gradient clipping will be deferred to the optimizer step.
# It will happen after the gradients have been reduced and before the optimizer step.
return opt.prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)

def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM or self.zero_1:
return self._prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
def clip_grad_norm_(self, parameters, max_norm, norm_type=2, postpone_clipping_to_optimizer_step: bool = False):
if postpone_clipping_to_optimizer_step:
parameters = list(parameters)
if len(self._optimizers) > 1:
raise RuntimeError(
"Postponing gradient clipping to the optimizer step is not possible when multiple optimizer were "
"prepared by the NeuronAccelerator."
)
self._optimizers[0].prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
else:
return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

def _custom_save_state(
self,
23 changes: 11 additions & 12 deletions optimum/neuron/accelerate/optimizer.py
@@ -69,6 +69,7 @@ def __init__(
self.parameters = []
self.parameter_ids = {}
self.clip_grad_norm_to_perform = None
self.grad_norm = None
if self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
self.parameters = [p for group in self.optimizer.param_groups for p in group["params"]]
self.parameter_ids = {id(p) for p in self.parameters}
@@ -78,9 +79,7 @@ def load_state_dict(self, state_dict):
return super().load_state_dict(state_dict)

def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
parameter_ids = {id(p) for p in parameters}
if parameter_ids == self.parameter_ids or isinstance(self.optimizer, ZeroRedundancyOptimizer):
self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
self.clip_grad_norm_to_perform = {"parameters": parameters, "max_norm": max_norm, "norm_type": norm_type}

@requires_neuronx_distributed
def step(self, closure=None):
@@ -100,22 +99,22 @@ def step(self, closure=None):
self.optimizer.max_norm = self.clip_grad_norm_to_perform["max_norm"]
else:
self.optimizer.grad_clipping = False
optimizer_args = {"closure": closure} if closure is not None else {}
self.optimizer.step(closure)
self.optimizer.step(closure=closure)
# Resetting everything.
self.optimizer.grad_clipping = False
self.clip_grad_norm_to_perform = None
elif self.accelerator_state.distributed_type is DistributedType.XLA:
optimizer_args = {"closure": closure} if closure is not None else {}
# By default barrier=False, but making sure it's the case here since we use ParalleLoader.
xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args, barrier=False)
elif self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
elif (
self.accelerator_state.distributed_type is DistributedType.XLA
or self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM
):
if parallel_layers.parallel_state.get_data_parallel_size() > 1:
bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
if self.clip_grad_norm_to_perform is not None:
parallel_layers.clip_grad_norm(self.parameters, **self.clip_grad_norm_to_perform)
parameters = self.clip_grad_norm_to_perform.pop("parameters", None)
if parameters is not None:
self.grad_norm = parallel_layers.clip_grad_norm(parameters, **self.clip_grad_norm_to_perform)
self.clip_grad_norm_to_perform = None
self.optimizer.step()
self.optimizer.step(closure=closure)
elif self.scaler is not None:
scale_before = self.scaler.get_scale()
self.scaler.step(self.optimizer, closure)
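
Taken together, the accelerator and optimizer changes implement a deferred-clipping pattern: `clip_grad_norm_(..., postpone_clipping_to_optimizer_step=True)` only records the request, and the clipping runs inside the optimizer step once gradients are final, storing the norm in `grad_norm`. Below is a self-contained sketch of that pattern, with `torch.nn.utils.clip_grad_norm_` standing in for `parallel_layers.clip_grad_norm`; the class and variable names are illustrative, not from the PR.

```python
import torch


class DeferredClipOptimizer:
    """Sketch of the deferred-clipping pattern: clip_grad_norm_ only records the request,
    and the clipping (plus the resulting grad norm) happens inside step(), once gradients
    are final."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.clip_grad_norm_to_perform = None
        self.grad_norm = None

    def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
        self.clip_grad_norm_to_perform = {
            "parameters": list(parameters),
            "max_norm": max_norm,
            "norm_type": norm_type,
        }

    def step(self):
        if self.clip_grad_norm_to_perform is not None:
            parameters = self.clip_grad_norm_to_perform.pop("parameters")
            # torch.nn.utils.clip_grad_norm_ stands in for parallel_layers.clip_grad_norm here.
            self.grad_norm = torch.nn.utils.clip_grad_norm_(parameters, **self.clip_grad_norm_to_perform)
            self.clip_grad_norm_to_perform = None
        self.optimizer.step()


model = torch.nn.Linear(4, 4)
opt = DeferredClipOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
model(torch.randn(2, 4)).sum().backward()
opt.prepare_clip_grad_norm(model.parameters(), max_norm=1.0)
opt.step()
print(opt.grad_norm)  # total norm computed during the step, analogous to the new grad_norm field
```
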
11 changes: 1 addition & 10 deletions optimum/neuron/accelerate/state.py
@@ -181,16 +181,7 @@

if self.distributed_type == DistributedType.XLA:
if mixed_precision == "bf16":
if autocast_backend is AutocastBackend.AMP:
self.downcast_bfloat = True
elif os.environ.get("ACCELERATE_DOWNCAST_BF16"):
os.environ["XLA_USE_BF16"] = str(0)
os.environ["XLA_DOWNCAST_BF16"] = str(1)
self.downcast_bfloat = True
else:
os.environ["XLA_USE_BF16"] = str(1)
os.environ["XLA_DOWNCAST_BF16"] = str(0)
self.downcast_bfloat = False
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

if mp_plugin is None:
mp_plugin = ModelParallelismPlugin()
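
The bf16 branch in the accelerator state now reduces to enabling stochastic rounding in the Neuron runtime; the `XLA_USE_BF16` / `XLA_DOWNCAST_BF16` switches are no longer set here. A trivial sketch of the remaining logic, with `mixed_precision` as a local stand-in for the state attribute:

```python
import os

mixed_precision = "bf16"  # stand-in for the accelerate mixed-precision setting
if mixed_precision == "bf16":
    # Enable stochastic rounding in the Neuron runtime for bf16 training.
    os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
```
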
22 changes: 0 additions & 22 deletions optimum/neuron/accelerate/utils/misc.py
@@ -21,7 +21,6 @@
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch
from transformers.modeling_utils import get_parameter_dtype

from ....utils import logging
from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
@@ -67,27 +66,6 @@ def patch_accelerate_is_torch_xla_available():
_ORIG_TORCH_FINFO = torch.finfo


def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False):
def patched_finfo(dtype):
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return _ORIG_TORCH_FINFO(torch.bfloat16)
return _ORIG_TORCH_FINFO(dtype)

return patched_finfo


def create_patched_get_parameter_dtype(
xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False
):
def patched_get_parameter_dtype(module):
dtype = get_parameter_dtype(module)
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return torch.bfloat16
return dtype

return patched_get_parameter_dtype


@requires_neuronx_distributed
@requires_safetensors
def torch_xla_safe_save_file(
8 changes: 7 additions & 1 deletion optimum/neuron/distributed/base.py
@@ -484,7 +484,7 @@ def initialize(mod: GQAQKVColumnParallelLinear, proj_name: str, output_size: int
else:
# TODO: change kv heads.
maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
mod, f"weight_{proj_name}", linear_layer=fake_linear_mod
mod, proj_name, f"weight_{proj_name}", linear_layer=fake_linear_mod
)
del fake_linear_mod

@@ -678,6 +678,9 @@ def should_parallelize_layer_predicate_func(layer):
"num_attention_heads": None,
"num_key_value_heads": None,
"kv_size_multiplier": None,
"fuse_qkv": None,
"q_output_size_per_partition": None,
"kv_output_size_per_partition": None,
}
for mod in model.modules():
if isinstance(mod, OptimumGQAQKVColumnParallelLinear):
@@ -690,6 +693,9 @@ def should_parallelize_layer_predicate_func(layer):
"num_attention_heads": num_attention_heads,
"num_key_value_heads": num_key_value_heads,
"kv_size_multiplier": kv_size_multiplier,
"fuse_qkv": mod.fuse_qkv,
"q_output_size_per_partition": mod.q_output_size_per_partition,
"kv_output_size_per_partition": mod.kv_output_size_per_partition,
}
break

93 changes: 60 additions & 33 deletions optimum/neuron/distributed/checkpointing.py
@@ -134,46 +134,73 @@ def consolidate_tensor_parallel_checkpoints(
for name in parameter_names:
# We need to handle the mapping between the GQA parameter names and the original names.
is_gqa_qkv_weight = name in gqa_qkv_names_to_original_names
is_fuse_qkv = gqa_qkv_metadata["fuse_qkv"]
if is_gqa_qkv_weight:
original_name = gqa_qkv_names_to_original_names[name]
weight_name = name.rsplit(".", maxsplit=1)[1]
if is_fuse_qkv:
original_names = [k for k, v in original_parameter_names_to_gqa_qkv_names.items() if v == name]
weight_names = [name.rsplit(".", maxsplit=1)[1] for name in original_names]
weight_names = ["weight_q", "weight_k", "weight_v"]
else:
original_names = [gqa_qkv_names_to_original_names[name]]
weight_names = [name.rsplit(".", maxsplit=1)[1]]
else:
original_name = name
weight_name = "" # Not needed.
original_names = [name]
weight_names = [""] # Not needed.

# For now all parameter metadatas are equal so it is enough to take the first element.
# This might not be the case anymore when `ParameterMetadata` uses slices.
sharded_metadata = sharded_metadatas[name]
if sharded_metadata.is_tied:
consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
else:
# Ensure that all tensors are contiguous before concatenating or further processing
weights = [state_dict[name].contiguous() for state_dict in state_dicts]
tp_size = len(weights)

full_weight = (
torch.cat(
weights,
dim=sharded_metadata.partition_dim,
)
.to("cpu")
.contiguous()
) # Ensure the result is also contiguous

if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
for original_name, weight_name in zip(original_names, weight_names):
if sharded_metadata.is_tied:
consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
else:
if is_fuse_qkv:
if weight_name == "weight_q":
s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"])
elif weight_name == "weight_k":
s = slice(
gqa_qkv_metadata["q_output_size_per_partition"],
gqa_qkv_metadata["q_output_size_per_partition"]
+ gqa_qkv_metadata["kv_output_size_per_partition"],
)
elif weight_name == "weight_v":
s = slice(
gqa_qkv_metadata["q_output_size_per_partition"]
+ gqa_qkv_metadata["kv_output_size_per_partition"],
None,
)
else:
s = slice(None, None)
else:
s = slice(None, None)

# Ensure that all tensors are contiguous before concatenating or further processing
weights = [state_dict[name][s].contiguous() for state_dict in state_dicts]
tp_size = len(weights)

full_weight = (
torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
)
elif weight_name == "weight_q" or original_name in gqa_qkv_output_projections_names:
full_weight = create_gqa_query_or_output_projection_weight_from_full_weight(
full_weight,
tp_size,
gqa_qkv_metadata["num_attention_heads"],
gqa_qkv_metadata["num_key_value_heads"],
gqa_qkv_metadata["kv_size_multiplier"],
"query" if weight_name == "weight_q" else "output",
)
consolidated_state_dict[original_name] = full_weight
torch.cat(
weights,
dim=sharded_metadata.partition_dim,
)
.to("cpu")
.contiguous()
) # Ensure the result is also contiguous

if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
full_weight = (
torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
)
elif weight_name == "weight_q" or original_name in gqa_qkv_output_projections_names:
full_weight = create_gqa_query_or_output_projection_weight_from_full_weight(
full_weight,
tp_size,
gqa_qkv_metadata["num_attention_heads"],
gqa_qkv_metadata["num_key_value_heads"],
gqa_qkv_metadata["kv_size_multiplier"],
"query" if weight_name == "weight_q" else "output",
)
consolidated_state_dict[original_name] = full_weight

return consolidated_state_dict

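
For reference, a small sketch of how a fused QKV shard is split back into separate q/k/v slices during consolidation, mirroring the slice logic above that uses `q_output_size_per_partition` and `kv_output_size_per_partition`. The sizes are made up for illustration.

```python
import torch

q_out, kv_out, hidden = 8, 4, 6  # per-partition output sizes (hypothetical values)
fused_shard = torch.randn(q_out + 2 * kv_out, hidden)  # one TP rank's fused QKV weight

slices = {
    "weight_q": slice(0, q_out),
    "weight_k": slice(q_out, q_out + kv_out),
    "weight_v": slice(q_out + kv_out, None),
}
per_projection = {name: fused_shard[s] for name, s in slices.items()}
assert per_projection["weight_q"].shape == (q_out, hidden)
assert per_projection["weight_k"].shape == (kv_out, hidden)
assert per_projection["weight_v"].shape == (kv_out, hidden)
```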