diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 1d91683ae4..cb398d8aa8 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -429,6 +429,47 @@ def test_pyt_autocast(
         assert x.grad.dtype == model_dtype
         assert op.weight.grad.dtype == model_dtype
 
+    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+    @pytest.mark.parametrize("heuristic", (None, "performance", "memory"))
+    def test_fp8_heuristics(
+        self,
+        *,
+        heuristic: Optional[str],
+        size: int = 16,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+    ) -> None:
+        """Test with FP8 heuristics"""
+
+        # Construct model
+        with te.fp8_model_init(heuristic=heuristic):
+            op = te_ops.Linear(size, size, bias=False, device=device, dtype=dtype)
+
+        # Check FP8 param
+        assert isinstance(op.weight, Float8Tensor)
+        if heuristic == "performance" or heuristic is None:
+            assert op.weight._transpose is not None
+            assert not op.weight._transpose_invalid
+        if heuristic == "memory":
+            assert op.weight._transpose is None
+
+        # Training loop
+        optim = torch.optim.SGD(op.parameters(), lr=1e-4)
+        for _ in range(3):
+            optim.zero_grad()
+            x = torch.randn((size, size), dtype=dtype, device=device, requires_grad=True)
+            y = op(x)
+            y.square().sum().backward()
+            optim.step()
+
+        # Check FP8 param
+        assert isinstance(op.weight, Float8Tensor)
+        if heuristic == "performance" or heuristic is None:
+            assert op.weight._transpose is not None
+            assert not op.weight._transpose_invalid
+        if heuristic == "memory":
+            assert op.weight._transpose is None
+
 
 class TestBasicOps:
     """Tests for individual operations"""
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 76679eb064..dfdb3ad933 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -78,6 +78,7 @@ class FP8GlobalStateManager:
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
    FP8_AUTOCAST_DEPTH = 0
+    FP8_HEURISTIC = None
     global_amax_buffer = {}
     global_amax_history_buffer = {}
     global_scale_buffer = {}
@@ -265,6 +266,11 @@ def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""
         return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing()
 
+    @classmethod
+    def fp8_heuristic(cls) -> Optional[str]:
+        """Heuristic for FP8 data format"""
+        return cls.FP8_HEURISTIC
+
     @classmethod
     def is_first_fp8_module(cls):
         """Returns `True` only the first time when called multiple
@@ -486,15 +492,17 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 
 @contextmanager
-def fp8_model_init(enabled: bool = True) -> None:
-    """
-    Context manager for FP8 initialization of parameters.
+def fp8_model_init(
+    enabled: bool = True,
+    heuristic: Optional[str] = None,
+) -> None:
+    """Context manager for FP8 initialization of parameters.
 
     Example usage:
 
     .. code-block:: python
 
-        with fp8_model_init(enabled=True):
+        with fp8_model_init():
             model = transformer_engine.pytorch.Linear(768, 768)
 
     Parameters
@@ -512,13 +520,21 @@ def fp8_model_init(enabled: bool = True) -> None:
              * LoRA-like fine-tuning, where the main parameters of the model do not change.
 
             This functionality is *EXPERIMENTAL*.
+    heuristic: string, optional
+               Heuristic for FP8 data format. Supported options are
+               "performance" (maximize runtime performance, default)
+               and "memory" (minimize memory usage).
+
     """
     _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
+    _fp8_heuristic = FP8GlobalStateManager.FP8_HEURISTIC
     FP8GlobalStateManager.FP8_PARAMETERS = enabled
+    FP8GlobalStateManager.FP8_HEURISTIC = heuristic
     try:
         yield
     finally:
         FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
+        FP8GlobalStateManager.FP8_HEURISTIC = _fp8_heuristic
 
 
 @contextmanager
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index a5fcf50465..c17102efb5 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -939,17 +939,25 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # If primary weights are in fp8, wrap the parameter as Float8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                # Dummy buffer to avoid overwriting amax history
                 dummy_amax = torch.empty(
                     (1, 1),
                     dtype=torch.float32,
                     device=param.device,
-                )  # Dummy buffer to avoid overwriting amax history
+                )
+
+                # Decide whether to store FP8 transpose
+                with_transpose_cache = torch.is_grad_enabled()
+                if with_transpose_cache and FP8GlobalStateManager.fp8_heuristic() == "memory":
+                    with_transpose_cache = False
+
+                # Cast param to FP8
                 param = Float8Tensor.to_float8(
                     param,
                     fp8_meta=self.fp8_meta,
                     fp8_meta_index=fp8_meta_index,
                     amax=dummy_amax,
-                    with_transpose_cache=torch.is_grad_enabled(),
+                    with_transpose_cache=with_transpose_cache,
                 )
 
             # Redo parameter wrap in case we broke it above
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index 859b1ba1d7..7b40d5637a 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -289,18 +289,27 @@ def reset_parameters(self) -> None:
 
         # Cast to FP8 if needed
         if self._with_fp8_parameters:
+
+            # Dummy buffer to avoid overwriting amax history
             dummy_amax = torch.empty(
                 (1, 1),
                 dtype=torch.float32,
                 device=self.device,
-            )  # Dummy buffer to avoid overwriting amax history
+            )
+
+            # Decide whether to store FP8 transpose
+            with_transpose_cache = torch.is_grad_enabled()
+            if with_transpose_cache and FP8GlobalStateManager.fp8_heuristic() == "memory":
+                with_transpose_cache = False
+
+            # Cast to FP8
             weight = Float8Tensor.to_float8(
                 weight,
                 fp8_meta=self.get_fp8_meta("param"),
                 fp8_meta_forward=True,
                 fp8_meta_index=0,
                 amax=dummy_amax,
-                with_transpose_cache=torch.is_grad_enabled(),
+                with_transpose_cache=with_transpose_cache,
             )
 
         # Save updated parameter