Update apex.mlp to use fp16 in autocast (NVIDIA#1477)
* apex.amp migration to torch.cuda.amp

Signed-off-by: Masaki Kozuki <[email protected]>

* add autocast tests

Signed-off-by: Masaki Kozuki <[email protected]>

* split with and without autocast

Signed-off-by: Masaki Kozuki <[email protected]>

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored Sep 9, 2022
1 parent 19ca350 commit 81c9aca
Showing 3 changed files with 47 additions and 22 deletions.
3 changes: 3 additions & 0 deletions apex/_autocast_utils.py
@@ -3,6 +3,9 @@
 import torch


+__all__ = ["_cast_if_autocast_enabled"]
+
+
 def _get_autocast_dtypes() -> Sequence[torch.dtype]:
     if torch.cuda.is_bf16_supported():
         return [torch.half, torch.bfloat16]
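Note: the _cast_if_autocast_enabled helper that apex/mlp/mlp.py now imports lives in this same file but sits outside the hunk shown above. A minimal sketch of the behavior implied by its call site, assuming it hands arguments back unchanged when autocast is off and otherwise casts floating-point tensors to the active autocast dtype; the non-tensor handling shown here is an assumption, not the verbatim apex code:

import torch


def _cast_if_autocast_enabled(*args):
    # Sketch only: with autocast disabled, return the arguments untouched.
    if not torch.is_autocast_enabled():
        return args
    # With autocast enabled, cast floating-point tensors to the active autocast
    # dtype (torch.float16 by default); non-tensor arguments such as the int
    # bias/activation flags used by apex.mlp pass through as-is.
    dtype = torch.get_autocast_gpu_dtype()
    return tuple(
        a.to(dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a
        for a in args
    )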
20 changes: 13 additions & 7 deletions apex/mlp/mlp.py
@@ -1,9 +1,12 @@
 from copy import copy
 import math
+
 import torch
 from torch import nn
+
+from apex._autocast_utils import _cast_if_autocast_enabled
 import mlp_cuda
-from .. import amp

+
 class MlpFunction(torch.autograd.Function):
     @staticmethod
@@ -21,8 +24,11 @@ def backward(ctx, grad_o):
         del ctx.outputs
         return (None, None, *grads)

-# TODO(crcrpar): Should make this compatible with torch.cuda.amp
-mlp_function = amp.half_function(MlpFunction.apply)
+
+def mlp_function(bias, activation, *args):
+    autocast_args = _cast_if_autocast_enabled(bias, activation, *args)
+    return MlpFunction.apply(*autocast_args)

+
 class MLP(torch.nn.Module):
     """Launch MLP in C++
@@ -33,16 +39,16 @@ class MLP(torch.nn.Module):
         relu (bool): Default True
     """
     def __init__(self, mlp_sizes, bias=True, activation='relu'):
-        super(MLP, self).__init__()
+        super().__init__()
         self.num_layers = len(mlp_sizes) - 1
         self.mlp_sizes = copy(mlp_sizes)
         self.bias = 1 if bias else 0

-        if activation is 'none':
+        if activation == 'none':
             self.activation = 0
-        elif activation is 'relu':
+        elif activation == 'relu':
             self.activation = 1
-        elif activation is 'sigmoid':
+        elif activation == 'sigmoid':
             self.activation = 2
         else:
             raise TypeError("activation must be relu or none.")
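With mlp_function now routing its arguments through _cast_if_autocast_enabled, the fused MLP can be called directly inside a torch.cuda.amp.autocast region, which is what the new tests below verify. A usage sketch, assuming apex is installed with the mlp_cuda extension built; the layer sizes and input shape are arbitrary placeholders:

import torch
from apex.mlp import MLP

# Arbitrary sizes for illustration: 480 -> 1024 -> 512, ReLU activations.
mlp = MLP([480, 1024, 512], bias=True, activation="relu").cuda()
x = torch.randn(64, 480, device="cuda", requires_grad=True)

with torch.cuda.amp.autocast():
    out = mlp(x)  # inputs and weights are cast to fp16 before the fused kernel runs
print(out.dtype)  # torch.float16 under autocast

# Backward outside the autocast region; gradients flow back through the cast
# to the original fp32 parameters.
out.float().mean().backward()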
46 changes: 31 additions & 15 deletions tests/L0/run_mlp/test_mlp.py
@@ -51,13 +51,7 @@ def test_numeric(self):
         self.assertEqual(test_input.grad, ref_input.grad)
         self.assertEqual(mlp.biases[0].grad, ref_mlp[0].bias.grad)

-    @common_utils.parametrize(
-        "use_activation,bias",
-        list(product(("none", "relu", "sigmoid"), (True, False))),
-    )
-    def test_mlp(self, use_activation: str, bias: bool):
-        # for use_activation in ["none", "relu", "sigmoid"]:
-        msg = f"activation: {use_activation}, bias: {bias}"
+    def _test_mlp_impl(self, use_activation: str, bias: bool, enable_autocast: bool):
         mlp = MLP(mlp_sizes, bias=bias, activation=use_activation).cuda()

         mlp_layers = []
@@ -81,15 +75,37 @@ def test_mlp(self, use_activation: str, bias: bool):
             .requires_grad_()
         )
         ref_input = test_input.clone().detach().requires_grad_()
-        mlp_out = mlp(test_input)
-        ref_out = ref_mlp(ref_input)
-        self.assertEqual(mlp_out, ref_out, msg=msg)
-
-        # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
-        mlp_out.mean().mul(10.0).backward()
-        ref_out.mean().mul(10.0).backward()
-        self.assertEqual(test_input.grad, ref_input.grad, msg=msg)
-        self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad, msg=msg)
+        with torch.cuda.amp.autocast_mode.autocast(enabled=enable_autocast):
+            mlp_out = mlp(test_input)
+            mlp_loss = mlp_out.mean().mul(10.0)
+            # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out
+            ref_out = ref_mlp(ref_input)
+            ref_loss = ref_out.mean().mul(10.0)
+
+        mlp_loss.backward()
+        ref_loss.backward()
+        if enable_autocast:
+            self.assertEqual(mlp_out.dtype, torch.float16)
+            self.assertEqual(ref_out.dtype, torch.float16)
+        else:
+            self.assertEqual(mlp_out, ref_out)
+            self.assertEqual(test_input.grad, ref_input.grad)
+            self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)
+
+    @common_utils.parametrize(
+        "use_activation,bias",
+        list(product(("none", "relu", "sigmoid"), (True, False))),
+    )
+    def test_mlp(self, use_activation: str, bias: bool):
+        self._test_mlp_impl(use_activation, bias, enable_autocast=False)
+
+    @common_utils.parametrize(
+        "use_activation,bias",
+        list(product(("none", "relu", "sigmoid"), (True, False))),
+    )
+    def test_mlp_autocast_fp16(self, use_activation: str, bias: bool):
+        self._test_mlp_impl(use_activation, bias, enable_autocast=True)

     def test_no_grad(self):
         mlp = MLP(mlp_sizes).cuda()
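For real fp16 training runs, the autocast forward pass exercised by these tests is normally paired with gradient scaling. A generic torch.cuda.amp training-step sketch; the optimizer choice, sizes, and random data below are placeholders and not part of this commit:

import torch
from apex.mlp import MLP

model = MLP([480, 1024, 256], bias=True, activation="relu").cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(32, 480, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        # Forward in fp16: mlp_function casts inputs/weights via _cast_if_autocast_enabled.
        loss = model(x).mean()
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)         # unscales gradients, then applies the optimizer step
    scaler.update()                # adjusts the scale factor for the next iteration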
