Commit

Deprecate outdated features (NVIDIA#1506)
* add warnings
* update README
* verify deprecated warning
* deprecate apex.RNN, which was never a finished feature

Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored Oct 10, 2022
1 parent b213772 commit 48bf0df
Showing 14 changed files with 183 additions and 74 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -12,6 +12,8 @@ The intent of Apex is to make up-to-date utilities available to users as quickly

## 1. Amp: Automatic Mixed Precision

**Deprecated. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)**

`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
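
For migration, a minimal sketch of the native PyTorch AMP training loop that replaces `amp.initialize`; the tiny linear model and random batches below are placeholders, not part of Apex:

```python
import torch

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    inputs = torch.randn(32, 128, device="cuda")
    targets = torch.randint(0, 10, (32,), device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)           # unscales gradients, then calls optimizer.step()
    scaler.update()                  # adjust the scale factor for the next iteration
```
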
@@ -29,6 +31,8 @@ different flags to `amp.initialize`.

## 2. Distributed Training

**`apex.parallel.DistributedDataParallel` is deprecated. Use [`torch.nn.parallel.DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel)**

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
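
A minimal sketch of the native replacement, assuming the script is launched with `torchrun` so that `LOCAL_RANK` is set in the environment; the linear model is a placeholder:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE; NCCL is the usual backend for GPU training.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 10).cuda()
model = DDP(model, device_ids=[local_rank])
```
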
@@ -44,6 +48,8 @@ shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.

### Synchronized Batch Normalization

**Deprecated. Use [`torch.nn.SyncBatchNorm`](https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html)**

`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
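
A small sketch of the native path, applying `convert_sync_batchnorm` to a placeholder model before wrapping it with `DistributedDataParallel`:

```python
import torch

# Swap every BatchNorm*d layer in the (placeholder) model for torch.nn.SyncBatchNorm;
# its running stats are then all-reduced across the process group during training.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```
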
2 changes: 2 additions & 0 deletions apex/RNN/README.md
@@ -1 +1,3 @@
**This module will be removed by the end of February 2023**

Under construction...
4 changes: 3 additions & 1 deletion apex/RNN/models.py
@@ -2,6 +2,7 @@

from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell

from apex import deprecated_warning
from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
from .cells import mLSTMRNNCell, mLSTMCell

@@ -10,6 +11,7 @@ def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
:class:`toRNNBackend`
"""

deprecated_warning("`apex.RNN` is deprecated and will be removed by the end of February 2023.")
if bidirectional:
return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
else:
@@ -43,7 +45,7 @@ def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, drop
"""
inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)

def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
"""
:class:`mLSTM`
17 changes: 17 additions & 0 deletions apex/__init__.py
@@ -5,8 +5,12 @@
import torch


__all__ = ["amp", "fp16_utils", "optimizers", "normalization", "transformer"]


if torch.distributed.is_available():
from . import parallel
__all__.append("parallel")

from . import amp
from . import fp16_utils
@@ -49,3 +53,16 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
)
return False
return True


class DeprecatedFeatureWarning(FutureWarning):
    pass


def deprecated_warning(msg: str) -> None:
    # Emit the warning from rank 0 only when a process group is initialized;
    # otherwise (no distributed support or no initialization) warn unconditionally.
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)
    ):
        warnings.warn(msg, DeprecatedFeatureWarning)
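
Since `DeprecatedFeatureWarning` subclasses `FutureWarning`, downstream code can verify or filter it with the standard `warnings` machinery. A small sketch, assuming `apex` imports cleanly in the environment; the message string is a placeholder:

```python
import warnings

import apex

# Capture the warning to verify that a deprecated entry point still warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    apex.deprecated_warning("example deprecation message")
assert any(issubclass(w.category, apex.DeprecatedFeatureWarning) for w in caught)

# Or silence the category entirely, e.g. while migrating a large code base.
warnings.filterwarnings("ignore", category=apex.DeprecatedFeatureWarning)
```
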
16 changes: 11 additions & 5 deletions apex/amp/amp.py
@@ -1,14 +1,14 @@
import functools
import itertools

import torch

from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
from ._amp_state import _amp_state
from .frontend import *

import functools
import itertools

import torch


_DECORATOR_HANDLE = None
_USER_CAST_REGISTRY = set()
@@ -28,16 +28,22 @@ def wrapper(*args, **kwargs):

# Decorator form
def half_function(fn):
from apex import deprecated_warning
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
return _decorator_helper(fn, utils.maybe_half, wrap_fn)


def float_function(fn):
from apex import deprecated_warning
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
return _decorator_helper(fn, utils.maybe_float, wrap_fn)


def promote_function(fn):
from apex import deprecated_warning
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
wrap_fn = functools.partial(wrap.make_promote_wrapper)
return _decorator_helper(fn, utils.maybe_float, wrap_fn)
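
For code that relied on these decorators, the closest native idiom is `torch.cuda.amp.autocast`, applied as a decorator or used to opt a region out of autocasting. A rough sketch under that assumption, not a drop-in equivalent of the caching cast wrappers:

```python
import torch

@torch.cuda.amp.autocast()
def fused_scale(x):
    # roughly the role of half_function: run this op under mixed precision
    return x * 2.0

def needs_fp32(x):
    # roughly the role of float_function: force full precision inside an autocast region
    with torch.cuda.amp.autocast(enabled=False):
        return torch.linalg.norm(x.float())
```
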

8 changes: 6 additions & 2 deletions apex/amp/frontend.py
@@ -1,7 +1,9 @@
from collections import OrderedDict

import torch

from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err, maybe_print
from collections import OrderedDict


class Properties(object):
@@ -305,6 +307,8 @@ def initialize(
.. _`let us know`:
https://github.com/NVIDIA/apex/issues
"""
from apex import deprecated_warning
deprecated_warning("apex.amp is deprecated and will be removed by the end of February 2023. Use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)")
_amp_state.opt_properties = Properties()
_amp_state.verbosity = verbosity

@@ -377,7 +381,7 @@ def load_state_dict(state_dict):
len(state_dict), len(_amp_state.loss_scalers)))

state_dict = state_dict.copy()

nb_loss_scalers = len(_amp_state.loss_scalers)
unexpected_keys = []
# Initialize idx outside, since unexpected_keys will increase it if enumerate is used