[PyTorch] Rename and clean up MXFP8 recipe class #1445

Draft
wants to merge 5 commits into base: release_v2.0
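The diff renames the `MXFP8BlockScaling` recipe class to `MXFP8Recipe` and replaces the `Recipe.mxfp8()` / `Recipe.delayed()` helper methods with `is_mxfp8` / `is_delayed_scaling` boolean attributes. A minimal before/after sketch of the user-facing change (a sketch only, assuming a `transformer_engine` build from this branch; the `fp8_autocast` call is shown purely for context):

```python
# Before/after sketch of the renamed recipe API; assumes transformer_engine
# built from this branch. The fp8_autocast usage is illustrative only.
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8Recipe  # formerly MXFP8BlockScaling

recipe = MXFP8Recipe()                # formerly: MXFP8BlockScaling()
assert recipe.is_mxfp8                # formerly: recipe.mxfp8()
assert not recipe.is_delayed_scaling  # formerly: recipe.delayed()

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    ...  # run FP8-capable modules here
```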
4 changes: 2 additions & 2 deletions tests/pytorch/distributed/run_numerics.py
@@ -16,9 +16,9 @@
import torch.distributed as dist

from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Format,
MXFP8Recipe,
Recipe,
)
from run_layer_with_overlap import _compare_tensors
@@ -44,7 +44,7 @@ def quantization_recipe() -> Recipe:
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
if QUANTIZATION == "mxfp8":
return MXFP8BlockScaling()
return MXFP8Recipe()
return te.fp8.get_default_fp8_recipe()


2 changes: 1 addition & 1 deletion tests/pytorch/distributed/test_fusible_ops.py
@@ -136,7 +136,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
return transformer_engine.common.recipe.MXFP8Recipe(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
4 changes: 2 additions & 2 deletions tests/pytorch/test_cuda_graphs.py
@@ -53,7 +53,7 @@ class ModelConfig:

fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.MXFP8Recipe(),
]

# Supported data types
@@ -315,7 +315,7 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.mxfp8() and not mxfp8_available:
if fp8_recipe.is_mxfp8 and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

# Run model with different CUDA graph settings.
2 changes: 1 addition & 1 deletion tests/pytorch/test_fusible_ops.py
@@ -148,7 +148,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
return transformer_engine.common.recipe.MXFP8Recipe(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
24 changes: 12 additions & 12 deletions tests/pytorch/test_numerics.py
@@ -98,7 +98,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq
mask_types = ["causal", "no_mask"]

fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.MXFP8Recipe(),
recipe.DelayedScaling(),
]

@@ -556,7 +556,7 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
if recipe.is_mxfp8 and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

config = model_configs[model]
@@ -668,7 +668,7 @@ def test_gpt_full_activation_recompute(
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
if recipe.is_mxfp8 and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

config = model_configs[model]
@@ -1416,9 +1416,9 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
if num_gemms > 1:
split_size = 1
if fp8:
if recipe.delayed():
if recipe.is_delayed_scaling:
split_size = 16
if recipe.mxfp8():
if recipe.is_mxfp8:
split_size = 128
m = config.seq_len // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
@@ -1463,10 +1463,10 @@ def test_grouped_linear_accuracy(
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
if recipe.is_mxfp8 and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and not recipe.is_delayed_scaling: # TODO(ksivamani): debug mismatches
pytest.skip("Grouped linear only supports FP8 delayed scaling")

config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
@@ -1648,10 +1648,10 @@ def test_padding_grouped_linear_accuracy(
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
if recipe.is_mxfp8 and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and not recipe.is_delayed_scaling: # TODO(ksivamani): debug mismatches
pytest.skip("Grouped linear only supports FP8 delayed scaling.")

config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
@@ -1860,7 +1860,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
if recipe.is_mxfp8 and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

config = model_configs[model]
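The test changes above all follow the same pattern: skip guards switch from `recipe.mxfp8()` to the new `recipe.is_mxfp8` attribute. A stand-alone sketch of that skip pattern, with stand-in recipe objects that expose only the new flags and a placeholder availability check:

```python
# Stand-alone sketch of the skip guard used in the tests above; the recipes
# are stand-ins exposing only the new boolean flags.
from types import SimpleNamespace

import pytest

mxfp8_available = False  # placeholder; normally probed from the GPU at import time
reason_for_no_mxfp8 = "MXFP8 requires Blackwell (SM 10.0) or newer"

fp8_recipes = [
    SimpleNamespace(is_delayed_scaling=True, is_mxfp8=False),  # like DelayedScaling()
    SimpleNamespace(is_delayed_scaling=False, is_mxfp8=True),  # like MXFP8Recipe()
]


@pytest.mark.parametrize("recipe", fp8_recipes)
def test_skip_guard(recipe):
    if recipe.is_mxfp8 and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    assert recipe.is_delayed_scaling  # only the delayed-scaling recipe gets here
```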
21 changes: 10 additions & 11 deletions transformer_engine/common/recipe/__init__.py
@@ -6,7 +6,7 @@
from __future__ import annotations
import warnings
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from typing import Callable, ClassVar, Literal, NamedTuple, Optional, Union
from pydantic.dataclasses import dataclass


@@ -44,19 +44,16 @@ class Recipe:
Base recipe class.
"""

def mxfp8(self):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)

def delayed(self):
"""Whether the given recipe is delayed scaling."""
return isinstance(self, DelayedScaling)
# Whether the recipe is FP8 with delayed scaling
is_delayed_scaling: bool = False
# Whether the recipe is MXFP8
is_mxfp8: bool = False


@dataclass()
class DelayedScaling(Recipe):
"""
Use the delayed scaling factor strategy. Use scale factor from previous
Use the FP8 delayed scaling factor strategy. Use the scale factor from the previous
iteration and record amax history of `amax_history_len` steps.

Parameters
@@ -141,6 +138,7 @@ def scaling_factor_compute(amax: Tensor,
reduce_amax: bool = True
fp8_dpa: bool = False
fp8_mha: bool = False
is_delayed_scaling: ClassVar[bool] = True

def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
@@ -162,9 +160,9 @@ def __repr__(self) -> str:


@dataclass()
class MXFP8BlockScaling(Recipe):
class MXFP8Recipe(Recipe):
"""
Use the current scaling factor strategy.
Use the MXFP8 1D block scaling strategy.

Parameters
----------
@@ -179,6 +177,7 @@ class MXFP8BlockScaling(Recipe):
fp8_format: Format = Format.E4M3
fp8_dpa: bool = False
fp8_mha: bool = False
is_mxfp8: ClassVar[bool] = True

def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
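The base-class change above drops the `mxfp8()` / `delayed()` methods in favor of boolean attributes that each subclass overrides with a `ClassVar`. A self-contained sketch of that pattern using plain `dataclasses` (the real classes use `pydantic.dataclasses` and carry many more fields):

```python
# Self-contained sketch of the ClassVar flag pattern; plain dataclasses stand
# in for pydantic, and most recipe fields are omitted.
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class Recipe:
    # Instance-level defaults on the base class ...
    is_delayed_scaling: bool = False
    is_mxfp8: bool = False


@dataclass
class DelayedScaling(Recipe):
    amax_history_len: int = 1024
    # ... overridden as a class-level constant in the subclass.
    is_delayed_scaling: ClassVar[bool] = True


@dataclass
class MXFP8Recipe(Recipe):
    is_mxfp8: ClassVar[bool] = True


assert DelayedScaling().is_delayed_scaling and not DelayedScaling().is_mxfp8
assert MXFP8Recipe().is_mxfp8 and not MXFP8Recipe().is_delayed_scaling
```

Because the override is a `ClassVar`, the flag stays out of the generated `__init__` and `__repr__`, so it behaves like a class-level constant rather than a configurable field.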
26 changes: 13 additions & 13 deletions transformer_engine/pytorch/fp8.py
@@ -13,7 +13,7 @@

import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling
from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8Recipe

from .constants import dist_group_type
from .utils import get_device_compute_capability
@@ -46,7 +46,7 @@ def check_mxfp8_support() -> Tuple[bool, str]:
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0): # blackwell and above
return MXFP8BlockScaling()
return MXFP8Recipe()
return DelayedScaling()


@@ -211,7 +211,7 @@ def add_fp8_tensors_to_global_buffer(
wrapper. For non CG case, it's called from within the module.
"""

if fp8_meta["recipe"].mxfp8():
if not fp8_meta["recipe"].is_delayed_scaling:
return

# Every module must call this function exactly once since
@@ -414,7 +414,7 @@ def fp8_autocast_enter(
if enabled:
fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
assert fp8_available, reason_for_no_fp8
if isinstance(fp8_recipe, MXFP8BlockScaling):
if isinstance(fp8_recipe, MXFP8Recipe):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8

@@ -434,7 +434,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -
to ensure both forward steps are numerically identical.
"""

if fp8_meta["recipe"].mxfp8():
if not fp8_meta["recipe"].is_delayed_scaling:
return

buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
@@ -460,7 +460,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
1 forward for identical numerical outputs.
"""

if fp8_meta["recipe"].mxfp8():
if not fp8_meta["recipe"].is_delayed_scaling:
return

# Store updated amaxes and scales from phase 1 post forward.
@@ -479,7 +479,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""

if fp8_meta["recipe"].mxfp8():
if not fp8_meta["recipe"].is_delayed_scaling:
return

fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
@@ -739,10 +739,10 @@ def create(
"""

cls = None
if recipe.delayed():
if recipe.is_delayed_scaling:
cls = DelayedScalingRecipeState
elif recipe.mxfp8():
cls = MXFP8BlockScalingRecipeState
elif recipe.is_mxfp8:
cls = MXFP8RecipeState
else:
raise ValueError(f"{recipe.__class__.__name__} is not supported")
return cls(
@@ -813,20 +813,20 @@ def make_quantizers(self) -> list:
]


class MXFP8BlockScalingRecipeState(RecipeState):
class MXFP8RecipeState(RecipeState):
"""Configuration for MXFP8 quantization.

MXFP8 quantization does not require state.

"""

recipe: MXFP8BlockScaling
recipe: MXFP8Recipe
mode: str
dtype: tex.DType

def __init__(
self,
recipe: MXFP8BlockScaling,
recipe: MXFP8Recipe,
*,
mode: str,
num_quantizers: int = 1,
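For reference, a hedged sketch of the compute-capability check behind `get_default_fp8_recipe()` above; `pick_default_recipe` is a hypothetical stand-in name, and a visible CUDA device is assumed:

```python
# Hedged sketch mirroring get_default_fp8_recipe() in fp8.py above;
# pick_default_recipe is a hypothetical helper, not part of the library.
import torch

from transformer_engine.common.recipe import DelayedScaling, MXFP8Recipe, Recipe


def pick_default_recipe() -> Recipe:
    """Prefer MXFP8 on Blackwell (SM 10.0+); otherwise use FP8 delayed scaling."""
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) >= (10, 0):
        return MXFP8Recipe()
    return DelayedScaling()
```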
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/module/base.py
@@ -22,7 +22,7 @@

from ._common import _ParameterInitMeta
from ..fp8 import (
MXFP8BlockScalingRecipeState,
MXFP8RecipeState,
DelayedScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
@@ -537,10 +537,10 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
# Return early if recipe state matches recipe
if self.fp8_meta_tensors_initialized:
recipe_state = self.fp8_meta[fp8_meta_tensor_key]
if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
if recipe.is_delayed_scaling and isinstance(recipe_state, DelayedScalingRecipeState):
self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
return
if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
if recipe.is_mxfp8 and isinstance(recipe_state, MXFP8RecipeState):
return

# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
Expand Down Expand Up @@ -635,7 +635,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor:
# Copy tensors to CPU and store
state = {}
state["recipe"] = self.fp8_meta["recipe"]
if state["recipe"].delayed():
if state["recipe"].is_delayed_scaling:
state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
@@ -694,7 +694,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
dst.copy_(src, non_blocking=True)

# Load tensors
if self.fp8_meta["recipe"].delayed():
if self.fp8_meta["recipe"].is_delayed_scaling:
copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
@@ -811,7 +811,7 @@ def prepare_forward(
self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)

if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].is_delayed_scaling:
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -86,8 +86,8 @@ def forward(
device = inp.device

# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("GroupedLinear does not yet support MXFP8")
if fp8 and not FP8GlobalStateManager.get_fp8_recipe().is_delayed_scaling:
raise NotImplementedError("GroupedLinear only supports FP8 delayed scaling")

# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
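One behavioral nuance in the hunk above: the old guard rejected MXFP8 specifically, while the new guard rejects any recipe whose `is_delayed_scaling` flag is false. A tiny stand-alone illustration with stand-in recipe objects:

```python
# Illustration of the widened GroupedLinear guard; recipes are stand-ins
# carrying only the new boolean flag.
from types import SimpleNamespace


def check_grouped_linear_recipe(fp8: bool, recipe) -> None:
    # Mirrors the guard above: with FP8 enabled, anything other than
    # delayed scaling is rejected, not just MXFP8.
    if fp8 and not recipe.is_delayed_scaling:
        raise NotImplementedError("GroupedLinear only supports FP8 delayed scaling")


check_grouped_linear_recipe(True, SimpleNamespace(is_delayed_scaling=True))  # passes
try:
    check_grouped_linear_recipe(True, SimpleNamespace(is_delayed_scaling=False))
except NotImplementedError:
    print("non-delayed-scaling recipe rejected, as expected")
```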