add async nan check utils #965

Closed
wants to merge 1 commit
49 changes: 49 additions & 0 deletions tests/utils/test_nan.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch

from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph


class NaNFunction(torch.autograd.Function):
@staticmethod
# pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
def forward(ctx, input):
return input.clone()

@staticmethod
# pyre-ignore overrides method defined in `torch.autograd.function._SingleLevelFunction` inconsistently
def backward(ctx, grad_output):
return torch.tensor([float("nan")], device="cpu")


class NanHookTest(unittest.TestCase):
def test_register_nan_hooks_on_whole_graph(self) -> None:
x = torch.tensor([1.0], device="cpu", requires_grad=True)
out = NaNFunction.apply(x)

# no error is thrown
out.backward()

_ = register_nan_hooks_on_whole_graph([out])
with self.assertRaisesRegex(RuntimeError, "Detected NaN"):
out.backward()

def test_check_for_nan_or_inf(self) -> None:
tensor = torch.tensor([float("nan")], device="cpu")

with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
check_for_nan_or_inf(tensor)

tensor = torch.tensor([float("inf")], device="cpu")
with self.assertRaisesRegex(RuntimeError, "Detected NaN or Inf in tensor"):
check_for_nan_or_inf(tensor)
3 changes: 3 additions & 0 deletions torchtnt/utils/__init__.py
@@ -51,6 +51,7 @@
ModuleSummary,
prune_module_summary,
)
from .nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph
from .oom import (
attach_oom_observer,
is_out_of_cpu_memory,
@@ -89,6 +90,8 @@
)

__all__ = [
"check_for_nan_or_inf",
"register_nan_hooks_on_whole_graph",
"IsNaNEvaluator",
"ThresholdEvaluator",
"CheckpointPath",
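
With the re-export above, both helpers become importable from the package root. A quick sanity check (a hypothetical snippet, assuming a torchtnt build that includes this diff):

import torch
from torchtnt.utils import check_for_nan_or_inf, register_nan_hooks_on_whole_graph

check_for_nan_or_inf(torch.ones(2))  # passes silently for a finite tensor
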
107 changes: 107 additions & 0 deletions torchtnt/utils/nan.py
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import deque
from typing import Callable, Iterator, List, Optional, Sequence, Union

import torch
from pyre_extensions import none_throws
from torch.autograd.graph import GradientEdge, Node
from torch.utils.hooks import RemovableHandle


def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, GradientEdge]) -> Node:
if isinstance(t, torch.Tensor):
return none_throws(t.grad_fn)
else:
# pyre-ignore Undefined attribute [16]: `GradientEdge` has no attribute `function`.
return t.function if t is not None else None


def register_nan_hooks_on_whole_graph( # noqa: C901
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]]
) -> Callable[[], None]:
"""
Registers a NaN hook on the whole autograd graph of the given tensors. The hook raises an error if a NaN is detected.

This is useful if you want training to halt when a NaN is detected during the backward pass (e.g. when the loss is inf or NaN).

Usage:

>>> class NaNFunction(torch.autograd.Function):
...     @staticmethod
...     def forward(ctx, input):
...         return input.clone()
...
...     @staticmethod
...     def backward(ctx, grad_output):
...         return torch.tensor([float("nan")], device="cpu")
>>> x = torch.tensor([1.0], device="cpu", requires_grad=True)
>>> out = NaNFunction.apply(x)
>>> _ = register_nan_hooks_on_whole_graph([out])
>>> out.backward()
RuntimeError: Detected NaN in 'grad_inputs[0]' after executing Node

"""

grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

def iter_graph(roots: List[torch.autograd.graph.Node]) -> Iterator[Node]:
if not roots:
return
seen = set()
q = deque()
for node in roots:
if node is not None and node not in seen:
seen.add(node)
q.append(node)
while q:
node = q.popleft()
for fn, _ in node.next_functions:
if fn is None or fn in seen:
continue
seen.add(fn)
q.append(fn)
yield node

def _assert_no_nan_tensor(t: Optional[torch.Tensor], msg: str) -> None:
if t is not None:
torch._assert_async(torch.logical_not(torch.any(torch.isnan(t))), msg)

def posthook(
grad_inputs: Sequence[Optional[torch.Tensor]],
grad_outputs: Sequence[Optional[torch.Tensor]],
) -> None:
node = torch._C._current_autograd_node()
for i, g_in in enumerate(grad_inputs):
_assert_no_nan_tensor(
g_in, f"Detected NaN in 'grad_inputs[{i}]' after executing Node: {node}"
)

handles: List[RemovableHandle] = []
for node in iter_graph(grad_fns):
posthandle = node.register_hook(posthook)
handles.append(posthandle)

def unregister_hooks() -> None:
for handle in handles:
handle.remove()

return unregister_hooks


def check_for_nan_or_inf(
tensor: torch.Tensor, msg: str = "Detected NaN or Inf in tensor"
) -> None:
"""
Asynchronously assert that the tensor contains neither NaN nor infinity. If the
tensor lives on a GPU, a failure surfaces as a CUDA device-side assert.
"""

torch._assert_async(
torch.logical_not(torch.any(torch.isnan(tensor) | torch.isinf(tensor))),
msg,
)
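
For context, a minimal sketch of how these two utilities might be wired into a training step. The model, optimizer, loss_fn, and batch below are placeholders and not part of this diff; on CPU the asserts raise immediately, while on GPU check_for_nan_or_inf surfaces later as a device-side assert.

import torch
from torchtnt.utils.nan import check_for_nan_or_inf, register_nan_hooks_on_whole_graph

def train_step(model, optimizer, loss_fn, batch):
    # Hypothetical training step: model/optimizer/loss_fn/batch are assumed to exist.
    inputs, targets = batch
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)

    # Async check on the loss itself.
    check_for_nan_or_inf(loss)

    # Register post-hooks on every node reachable from the loss, run backward,
    # then remove the hooks so later iterations are not double-instrumented.
    unregister = register_nan_hooks_on_whole_graph([loss])
    try:
        loss.backward()
    finally:
        unregister()

    optimizer.step()
    return loss.detach()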