[PyTorch] Add heuristics for initializing FP8 params #1300

Open · wants to merge 4 commits into base: main
41 changes: 41 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -429,6 +429,47 @@ def test_pyt_autocast(
assert x.grad.dtype == model_dtype
assert op.weight.grad.dtype == model_dtype

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("heuristic", (None, "performance", "memory"))
def test_fp8_heuristics(
self,
*,
heuristic: Optional[str],
size: int = 16,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
) -> None:
"""Test with FP8 heuristics"""

# Construct model
with te.fp8_model_init(heuristic=heuristic):
op = te_ops.Linear(size, size, bias=False, device=device, dtype=dtype)

# Check FP8 param
assert isinstance(op.weight, Float8Tensor)
if heuristic == "performance" or heuristic is None:
assert op.weight._transpose is not None
assert not op.weight._transpose_invalid
if heuristic == "memory":
assert op.weight._transpose is None

# Training loop
optim = torch.optim.SGD(op.parameters(), lr=1e-4)
for _ in range(3):
optim.zero_grad()
x = torch.randn((size, size), dtype=dtype, device=device, requires_grad=True)
y = op(x)
y.square().sum().backward()
optim.step()

# Check FP8 param
assert isinstance(op.weight, Float8Tensor)
if heuristic == "performance" or heuristic is None:
assert op.weight._transpose is not None
assert not op.weight._transpose_invalid
if heuristic == "memory":
assert op.weight._transpose is None


class TestBasicOps:
"""Tests for individual operations"""
24 changes: 20 additions & 4 deletions transformer_engine/pytorch/fp8.py
@@ -78,6 +78,7 @@ class FP8GlobalStateManager:
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
FP8_AUTOCAST_DEPTH = 0
FP8_HEURISTIC = None
global_amax_buffer = {}
global_amax_history_buffer = {}
global_scale_buffer = {}
@@ -265,6 +266,11 @@ def fp8_graph_capturing(cls) -> bool:
"""Is CUDA graph capture under way?"""
return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()

@classmethod
def fp8_heuristic(cls) -> Optional[str]:
"""Heuristic for FP8 data format"""
return cls.FP8_HEURISTIC

@classmethod
def is_first_fp8_module(cls):
"""Returns `True` only the first time when called multiple
@@ -486,15 +492,17 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:


@contextmanager
-def fp8_model_init(enabled: bool = True) -> None:
-    """
-    Context manager for FP8 initialization of parameters.
+def fp8_model_init(
+    enabled: bool = True,
+    heuristic: Optional[str] = None,
+) -> None:
+    """Context manager for FP8 initialization of parameters.

Example usage:

.. code-block:: python

-        with fp8_model_init(enabled=True):
+        with fp8_model_init():
model = transformer_engine.pytorch.Linear(768, 768)

Parameters
@@ -512,13 +520,21 @@ def fp8_model_init(enabled: bool = True) -> None:
* LoRA-like fine-tuning, where the main parameters of the model do not change.

This functionality is *EXPERIMENTAL*.
heuristic: string, optional
Heuristic for FP8 data format. Supported options are
"performance" (maximize runtime performance, default)
and "memory" (minimize memory usage).

"""
_fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
_fp8_heuristic = FP8GlobalStateManager.FP8_HEURISTIC
FP8GlobalStateManager.FP8_PARAMETERS = enabled
FP8GlobalStateManager.FP8_HEURISTIC = heuristic
try:
yield
finally:
FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
FP8GlobalStateManager.FP8_HEURISTIC = _fp8_heuristic


@contextmanager
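Note: a minimal usage sketch of the new `heuristic` argument, mirroring the updated docstring and the test above (module aliases and anything not shown in the diff are assumptions):

```python
import transformer_engine.pytorch as te

# Initialize parameters directly in FP8 and skip the cached FP8 transpose
# to reduce memory usage; omitting heuristic (or passing "performance")
# keeps the transpose cache for faster backward-pass GEMMs.
with te.fp8_model_init(heuristic="memory"):
    model = te.Linear(768, 768)
```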
12 changes: 10 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -939,17 +939,25 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
# If primary weights are in fp8, wrap the parameter as Float8Tensor
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
if self.primary_weights_in_fp8 and fp8_meta_index is not None:
# Dummy buffer to avoid overwriting amax history
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=param.device,
-) # Dummy buffer to avoid overwriting amax history
+)

# Decide whether to store FP8 transpose
with_transpose_cache = torch.is_grad_enabled()
if with_transpose_cache and FP8GlobalStateManager.fp8_heuristic() == "memory":
with_transpose_cache = False

# Cast param to FP8
param = Float8Tensor.to_float8(
param,
fp8_meta=self.fp8_meta,
fp8_meta_index=fp8_meta_index,
amax=dummy_amax,
-with_transpose_cache=torch.is_grad_enabled(),
+with_transpose_cache=with_transpose_cache,
)

# Redo parameter wrap in case we broke it above
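A small sanity check of the effect on module parameters, based on the assertions in the test above (treating `Float8Tensor._transpose` as an inspectable internal attribute is an assumption carried over from that test):

```python
import transformer_engine.pytorch as te

with te.fp8_model_init(heuristic="memory"):
    linear = te.Linear(16, 16, bias=False, device="cuda")

# With the "memory" heuristic, reset_parameters() casts the weight to FP8
# without a transpose cache, so the transpose is recomputed when needed.
assert linear.weight._transpose is None
```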
13 changes: 11 additions & 2 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -289,18 +289,27 @@ def reset_parameters(self) -> None:

# Cast to FP8 if needed
if self._with_fp8_parameters:

# Dummy buffer to avoid overwriting amax history
dummy_amax = torch.empty(
(1, 1),
dtype=torch.float32,
device=self.device,
-) # Dummy buffer to avoid overwriting amax history
+)

# Decide whether to store FP8 transpose
with_transpose_cache = torch.is_grad_enabled()
if with_transpose_cache and FP8GlobalStateManager.fp8_heuristic() == "memory":
with_transpose_cache = False

# Cast to FP8
weight = Float8Tensor.to_float8(
weight,
fp8_meta=self.get_fp8_meta("param"),
fp8_meta_forward=True,
fp8_meta_index=0,
amax=dummy_amax,
-with_transpose_cache=torch.is_grad_enabled(),
+with_transpose_cache=with_transpose_cache,
)

# Save updated parameter
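The fusible-ops path gets the same treatment; a sketch mirroring the new test (assuming `te_ops` refers to `transformer_engine.pytorch.ops`, as in the test file):

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

with te.fp8_model_init(heuristic="performance"):
    op = te_ops.Linear(16, 16, bias=False, device="cuda", dtype=torch.float32)

# The default "performance" heuristic (also used when heuristic is None)
# keeps a valid FP8 transpose cache alongside the weight.
assert op.weight._transpose is not None
assert not op.weight._transpose_invalid
```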