use torch.tensor to create a tensor with initializer values (NVIDIA#1588)

* use `torch.tensor` with init values

Signed-off-by: Masaki Kozuki <[email protected]>

* Update apex/contrib/sparsity/sparse_masklib.py

* remove torch._six

Signed-off-by: Masaki Kozuki <[email protected]>

* retire `torch._six`

as per the upstream commit `b005ec62b9`.

Signed-off-by: Masaki Kozuki <[email protected]>

* use std collections.abc

Signed-off-by: Masaki Kozuki <[email protected]>

---------

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored Feb 16, 2023
1 parent 93bb36a commit 6943fd2
Showing 4 changed files with 14 additions and 21 deletions.
10 changes: 0 additions & 10 deletions apex/amp/_amp_state.py
@@ -2,18 +2,8 @@
 # I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
 # But apparently it's ok:
 # http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
-import os
 import torch
 
-TORCH_MAJOR = int(torch.__version__.split('.')[0])
-TORCH_MINOR = int(torch.__version__.split('.')[1])
-
-
-if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
-    from torch._six import container_abcs
-else:
-    import collections.abc as container_abcs
-
 
 class AmpState(object):
     def __init__(self):
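
For context, a minimal sketch (not part of this diff) of how a `container_abcs` alias is typically used; the standard-library module offers the same ABCs that the removed `torch._six` shim re-exported for pre-1.8 PyTorch. The helper below is purely illustrative:

import collections.abc as container_abcs

def flatten_values(obj):
    # Illustrative helper: walk dict-like and list-like containers alike,
    # relying only on the stdlib ABCs instead of torch._six.
    if isinstance(obj, container_abcs.Mapping):
        for v in obj.values():
            yield from flatten_values(v)
    elif isinstance(obj, container_abcs.Sequence) and not isinstance(obj, str):
        for v in obj:
            yield from flatten_values(v)
    else:
        yield obj
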
14 changes: 8 additions & 6 deletions apex/amp/_initialize.py
@@ -1,11 +1,13 @@
-import torch
-from torch._six import string_classes
+import collections.abc as container_abcs
+from types import MethodType
 import functools
-import numpy as np
 import sys
-from types import MethodType
 import warnings
-from ._amp_state import _amp_state, warn_or_err, container_abcs
+
+import numpy as np
+import torch
+
+from ._amp_state import _amp_state, warn_or_err
 from .handle import disable_casts
 from .scaler import LossScaler
 from ._process_optimizer import _process_optimizer
@@ -39,7 +41,7 @@ def to_type(dtype, t):
 def applier(value, fn):
     if isinstance(value, torch.Tensor):
         return fn(value)
-    elif isinstance(value, string_classes):
+    elif isinstance(value, str):
         return value
     elif isinstance(value, np.ndarray):
         return value
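
The `string_classes` tuple existed for Python 2/3 compatibility; on Python 3 it is simply `(str,)`, so the check collapses to the built-in. A hedged sketch (hypothetical helper, not apex's `applier`) of why strings need their own branch before any generic container handling:

import collections.abc as abc

import torch

def apply_to_tensors(value, fn):
    # Strings are also Sequences; test `str` first or the recursion would
    # descend into single characters instead of returning the string whole.
    if isinstance(value, torch.Tensor):
        return fn(value)
    if isinstance(value, str):
        return value
    if isinstance(value, abc.Mapping):
        return {k: apply_to_tensors(v, fn) for k, v in value.items()}
    if isinstance(value, abc.Sequence):
        return [apply_to_tensors(v, fn) for v in value]
    return value

For example, apply_to_tensors({"x": torch.ones(2)}, lambda t: t.half()) casts every tensor in the container while leaving strings and scalars untouched.
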
7 changes: 4 additions & 3 deletions apex/contrib/clip_grad/clip_grad.py
@@ -1,17 +1,18 @@
-import torch
-from torch._six import inf
 from typing import Union, Iterable
+
+import torch
+
 _kernel_import_succeeded = False
 try:
     import amp_C
     from apex.multi_tensor_apply import multi_tensor_applier
     _kernel_import_succeeded = True
-except:
+except ImportError:
     _kernel_import_succeeded = False
 
 _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
 
 
 def clip_grad_norm_(
         parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
         error_if_nonfinite: bool = False) -> torch.Tensor:
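
Two notes on this hunk: the bare `except:` becomes `except ImportError:` so that only a missing optional extension triggers the Python fallback, and `torch._six.inf` was only ever a re-export of the float infinity the standard library already provides. A small sketch under those assumptions:

import math

import torch

# The fused kernels are a soft dependency: catching ImportError only,
# rather than a bare except, avoids hiding unrelated bugs or Ctrl-C.
try:
    import amp_C  # compiled apex extension, may be absent
    _kernel_import_succeeded = True
except ImportError:
    _kernel_import_succeeded = False

# math.inf is the same IEEE infinity that torch._six.inf used to alias.
assert math.inf == float("inf")
assert torch.isinf(torch.tensor(math.inf))
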
4 changes: 2 additions & 2 deletions apex/contrib/sparsity/sparse_masklib.py
@@ -29,8 +29,8 @@ def compute_valid_1d_patterns(m,n):
     if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns
     patterns = torch.zeros(m)
     patterns[:n] = 1
-    valid_patterns = torch.empty(list(set(permutations(patterns.tolist()))))
-    if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
+    valid_patterns = torch.tensor(list(set(permutations(patterns.tolist()))))
+    if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns
     return valid_patterns
 
 """ m:n 1d structured best """
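
This fix matters because `torch.empty` interprets its argument as a shape and returns uninitialized memory, whereas `torch.tensor` copies the supplied values, which is what the pattern table needs. A minimal sketch of the difference (values chosen for illustration):

import torch

rows = [(1.0, 1.0, 0.0, 0.0),
        (1.0, 0.0, 1.0, 0.0)]

# torch.tensor builds a 2x4 tensor that actually holds these m=4, n=2 patterns.
patterns = torch.tensor(rows)

# torch.empty only takes sizes; this allocates a 2x4 block of
# uninitialized memory with arbitrary contents.
scratch = torch.empty(2, 4)
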
