Fix mixed precision
michaelbenayoun committed Feb 11, 2025
1 parent 57bdede commit c09e1b5
Showing 4 changed files with 4 additions and 66 deletions.
17 changes: 3 additions & 14 deletions optimum/neuron/accelerate/accelerator.py
@@ -62,7 +62,6 @@
)
from .utils.misc import (
apply_activation_checkpointing,
create_patched_finfo,
create_patched_save_pretrained,
)
from .utils.operations import _xla_gather
@@ -327,19 +326,6 @@ def patch_model_for_neuron(
# Working on a copy for safety.
patching_specs = list(patching_specs)

mixed_precision_is_bf16 = self.state.mixed_precision == "bf16"
patched_finfo = create_patched_finfo(
xla_downcast_bf16=mixed_precision_is_bf16 and self.state.downcast_bfloat,
use_amp=mixed_precision_is_bf16 and self.state.autocast_backend is AutocastBackend.AMP,
xla_use_bf16=mixed_precision_is_bf16 and not self.state.downcast_bfloat,
)
patching_specs.append(
(
"forward",
DynamicPatch(patch_within_function(("torch.finfo", patched_finfo))),
),
)

if isinstance(model, PreTrainedModel):
patching_specs.append(
(
@@ -457,6 +443,9 @@ def prepare_model(

model = self.patch_model_for_neuron(model)

if self.state.mixed_precision == "bf16":
model.to(torch.bfloat16)

# We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
# need.
model.config.use_cache = False
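
The effect of this hunk: instead of patching torch.finfo inside the model's forward pass, the accelerator now casts the module itself to bfloat16 when bf16 mixed precision is requested. A minimal sketch of the new flow, assuming a generic nn.Module (the function name below is illustrative, not the actual NeuronAccelerator.prepare_model, which also applies Neuron-specific patches and parallelism):

import torch
from torch import nn

def prepare_model_for_bf16(model: nn.Module, mixed_precision: str) -> nn.Module:
    # Cast parameters and buffers to bfloat16 up front instead of rewriting
    # torch.finfo results at forward time, as the removed patch used to do.
    if mixed_precision == "bf16":
        model = model.to(torch.bfloat16)
    return model

model = prepare_model_for_bf16(nn.Linear(4, 4), mixed_precision="bf16")
assert next(model.parameters()).dtype == torch.bfloat16
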
11 changes: 1 addition & 10 deletions optimum/neuron/accelerate/state.py
@@ -181,16 +181,7 @@ def __init__(

if self.distributed_type == DistributedType.XLA:
if mixed_precision == "bf16":
if autocast_backend is AutocastBackend.AMP:
self.downcast_bfloat = True
elif os.environ.get("ACCELERATE_DOWNCAST_BF16"):
os.environ["XLA_USE_BF16"] = str(0)
os.environ["XLA_DOWNCAST_BF16"] = str(1)
self.downcast_bfloat = True
else:
os.environ["XLA_USE_BF16"] = str(1)
os.environ["XLA_DOWNCAST_BF16"] = str(0)
self.downcast_bfloat = False
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

if mp_plugin is None:
mp_plugin = ModelParallelismPlugin()
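
With the XLA_USE_BF16 / XLA_DOWNCAST_BF16 branches removed, the state now only enables stochastic rounding in the Neuron runtime when bf16 is requested; the dtype itself is handled by the model cast in prepare_model. A reduced sketch of the new behaviour, not the full NeuronAcceleratorState.__init__:

import os

def configure_bf16(mixed_precision: str) -> None:
    # Stochastic rounding helps preserve accuracy when accumulating in
    # bfloat16 on Trainium; the XLA dtype environment variables are no
    # longer touched here.
    if mixed_precision == "bf16":
        os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

configure_bf16("bf16")
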
22 changes: 0 additions & 22 deletions optimum/neuron/accelerate/utils/misc.py
@@ -21,7 +21,6 @@
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

import torch
from transformers.modeling_utils import get_parameter_dtype

from ....utils import logging
from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere
@@ -67,27 +66,6 @@ def patch_accelerate_is_torch_xla_available():
_ORIG_TORCH_FINFO = torch.finfo


def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False):
def patched_finfo(dtype):
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return _ORIG_TORCH_FINFO(torch.bfloat16)
return _ORIG_TORCH_FINFO(dtype)

return patched_finfo


def create_patched_get_parameter_dtype(
xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False
):
def patched_get_parameter_dtype(module):
dtype = get_parameter_dtype(module)
if xla_downcast_bf16 or use_amp or xla_use_bf16:
return torch.bfloat16
return dtype

return patched_get_parameter_dtype


@requires_neuronx_distributed
@requires_safetensors
def torch_xla_safe_save_file(
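
For reference, the deleted create_patched_finfo helper returned a closure that answered every torch.finfo query with the bfloat16 limits whenever one of the bf16 flags was set. A small, self-contained illustration of the behaviour this commit removes (reproducing the deleted helper):

import torch

_ORIG_TORCH_FINFO = torch.finfo

def create_patched_finfo(xla_downcast_bf16: bool = False, use_amp: bool = False, xla_use_bf16: bool = False):
    def patched_finfo(dtype):
        # Any active bf16 flag makes finfo report bfloat16 limits, regardless
        # of the dtype actually passed in.
        if xla_downcast_bf16 or use_amp or xla_use_bf16:
            return _ORIG_TORCH_FINFO(torch.bfloat16)
        return _ORIG_TORCH_FINFO(dtype)
    return patched_finfo

patched = create_patched_finfo(xla_use_bf16=True)
assert patched(torch.float32).max == torch.finfo(torch.bfloat16).max   # forced to bf16
assert create_patched_finfo()(torch.float32).max == torch.finfo(torch.float32).max  # flags off: unchanged
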
20 changes: 0 additions & 20 deletions optimum/neuron/training_args.py
@@ -32,7 +32,6 @@
from .accelerate import NeuronAcceleratorState, NeuronPartialState
from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_torch_xla_available
from .utils import is_main_worker
from .utils.misc import is_precompilation
from .utils.patching import Patcher, patch_within_function
from .utils.torch_xla_and_neuronx_initialization import set_neuron_cc_optlevel

@@ -181,25 +180,6 @@ def __post_init__(self):
async_save=self.async_save,
)

# If the user did not specify bf16=True but the flags are set, we set bf16=True.
# Without this we can fall in the case where XLA will compile the graph in bf16 with torch.finfo unpatched,
# leading to NaNs.
if not self.bf16 and (
os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1"
):
self.bf16 = True

if (
is_precompilation()
and self.bf16
and os.environ.get("XLA_USE_BF16", "0") == "0"
and os.environ.get("XLA_DOWNCAST_BF16", "0") == "0"
):
raise ValueError(
"bf16=True but both of the environment variables XLA_USE_BF16 and XLA_DOWNCAST_BF16 are not set. You "
"must set them manually when using `neuron_parallel_compile`."
)

if self.bf16 and self.half_precision_backend == "amp":
os.environ["ACCELERATE_USE_AMP"] = "true"
else:
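
After this change the XLA environment variables no longer participate in the decision: requesting bf16 through the training arguments is sufficient, and neuron_parallel_compile no longer demands that XLA_USE_BF16 or XLA_DOWNCAST_BF16 be exported. A hedged usage sketch, assuming the public NeuronTrainingArguments API (output_dir is a placeholder):

from optimum.neuron import NeuronTrainingArguments

# bf16 mixed precision is requested with the flag alone; the accelerator
# casts the model to bfloat16 and turns on stochastic rounding internally.
training_args = NeuronTrainingArguments(
    output_dir="output",  # placeholder
    bf16=True,
)
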
