Using regular layernorm for ONNX and TS export (NVIDIA#1634)
* Using regular layernorm for ONNX and TS export

Signed-off-by: Boris Fomitchev <[email protected]>

* Added unit test for export

Signed-off-by: Boris Fomitchev <[email protected]>

* Tests modified to not depend on onnxruntime

Signed-off-by: Boris Fomitchev <[email protected]>

* moved tests inside the class

Signed-off-by: Boris Fomitchev <[email protected]>

* remove white spaces from empty lines

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
borisfom and crcrpar authored Apr 27, 2023
1 parent f2d6f29 commit 85e9edd
Showing 2 changed files with 37 additions and 7 deletions.
8 changes: 4 additions & 4 deletions apex/normalization/fused_layer_norm.py
@@ -286,7 +286,7 @@ def reset_parameters(self):
             init.zeros_(self.bias)

     def forward(self, input):
-        if not input.is_cuda:
+        if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda:
             return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
         if self.elementwise_affine:
             return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)
@@ -379,7 +379,7 @@ def reset_parameters(self):
             init.ones_(self.weight)

     def forward(self, input):
-        if not input.is_cuda:
+        if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda:
             return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)

         if self.elementwise_affine:
@@ -409,7 +409,7 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs):

     def forward(self, input: torch.Tensor):
         # NOTE (mkozuki): CPU path is here mainly for unittest sake.
-        if not input.is_cuda:
+        if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda:
             return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
         return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps)

@@ -432,6 +432,6 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs):
     def forward(self, input: torch.Tensor):
         # NOTE (mkozuki): CPU path is here mainly for unittest sake.
         # TODO Manual RMS Norm Implementation Here
-        if not input.is_cuda:
+        if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda:
             return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
         return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps)
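For readers skimming the diff: all four hunks apply the same pattern, a dispatch at the top of forward() that routes to the plain PyTorch op whenever the module is being traced or scripted (i.e. during ONNX or TorchScript export) or is running on CPU, and to the fused CUDA kernel otherwise. A minimal runnable sketch of that pattern using only stock PyTorch; the class name is made up, and F.layer_norm stands in for apex's fused kernel on the CUDA branch:

import torch
import torch.nn.functional as F

class ExportAwareLayerNorm(torch.nn.LayerNorm):
    # Illustrative only -- mirrors the dispatch added in this commit,
    # not apex's actual implementation.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda:
            # Export/CPU path: the standard op maps to ordinary
            # ONNX/TorchScript nodes that every backend understands.
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
        # Eager CUDA path: apex would call its fused kernel here;
        # F.layer_norm keeps the sketch self-contained and runnable.
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)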
36 changes: 33 additions & 3 deletions tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
@@ -1,5 +1,4 @@
 import torch
-
 from apex.normalization import FusedLayerNorm
 from apex.normalization import FusedRMSNorm
 from apex.normalization import MixedFusedLayerNorm
@@ -233,7 +232,6 @@ def test_autocast_fused_layer_norm(self, dtype, elementwise_affine):

         tols = {'rtol': None, 'atol': None} if dtype == torch.half else bf16_bwd_thresholds
         torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False)
-
     @common_utils.parametrize(
         "dtype, elementwise_affine",
         list(product(autocast_dtypes, (True, False)))
@@ -266,8 +264,40 @@ def test_autocast_fused_rms_norm(self, dtype, elementwise_affine):
         tols = {'rtol': 1e-3, 'atol': 1e-3} if dtype == torch.half else bf16_bwd_thresholds
         torch.testing.assert_close(native_x.grad.cuda(), fused_x.grad, **tols, check_dtype=False)

+    def _verify_export(self, fused, fused_x):
+        # check that export() is working
+        onnx_str = torch.onnx.export_to_pretty_string(fused, (fused_x,),
+                                                      input_names=['x_in'],
+                                                      )
+        assert 'x_in' in onnx_str
+        assert 'ReduceMean' in onnx_str
+
+    def test_rms_export(self):
+        batch_size = 16
+        normalized_shape = [32, 16]
+        fused = FusedRMSNorm(
+            normalized_shape=normalized_shape, elementwise_affine=True
+        ).cuda()
+        fused_m = MixedFusedRMSNorm(
+            normalized_shape=normalized_shape, elementwise_affine=True
+        ).cuda()
+        native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32)
+        self._verify_export(fused, fused_x)
+        self._verify_export(fused_m, fused_x)
+
+    def test_layer_norm_export(self):
+        batch_size = 16
+        normalized_shape = [32, 16]
+        fused = FusedLayerNorm(
+            normalized_shape=normalized_shape, elementwise_affine=True
+        ).cuda()
+        fused_m = MixedFusedLayerNorm(
+            normalized_shape=normalized_shape, elementwise_affine=True
+        ).cuda()
+        native_x, fused_x = _prep_inputs(batch_size, normalized_shape, torch.float32)
+        self._verify_export(fused, fused_x)
+        self._verify_export(fused_m, fused_x)

 instantiate_device_type_tests(TestFusedLayerNorm, globals(), only_for=("cuda",))

 if __name__ == "__main__":
     common_utils.run_tests()

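Usage note (not part of the commit): because tracing now takes the F.layer_norm / manual_rms_norm branch, the fused modules can be exported through the standard torch.onnx API. A sketch only, assuming apex is installed and a CUDA device is present; the file name, shapes, and input name are arbitrary:

import torch
from apex.normalization import FusedLayerNorm

model = FusedLayerNorm(normalized_shape=[32, 16]).cuda().eval()
x = torch.randn(16, 32, 16, device="cuda")  # arbitrary example input
# Tracing triggers torch.jit.is_tracing(), so the exported graph is built
# from regular ONNX ops (e.g. ReduceMean), which is what the new
# _verify_export test checks via export_to_pretty_string.
torch.onnx.export(model, (x,), "fused_layer_norm.onnx", input_names=["x_in"])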