Comparing changes

This is a direct comparison between two commits made in this repository.

base repository: pytorch/tnt
base: 4ca7a4c2e1dc1790f2640134ad3e0e7939fb3eae
head repository: pytorch/tnt
compare: d9cbacb569a342cfdc83ca4a0dea26dfe9ad5201
Showing with 10 additions and 10 deletions.
  1. +10 −10 tests/utils/test_prepare_module.py
20 changes: 10 additions & 10 deletions tests/utils/test_prepare_module.py
@@ -22,10 +22,12 @@
     TorchCompileParams,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed
-from torchtnt.utils.version import Version
+from torchtnt.utils.version import is_torch_version_geq


 class PrepareModelTest(unittest.TestCase):
+    torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")
+
     def test_invalid_fsdp_strategy_str_values(self) -> None:
         from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
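
The hunk above replaces the Version import with torchtnt's is_torch_version_geq helper and caches its result once as a class attribute, so the tests below can all gate on the same boolean. As a rough sketch of what such a helper can look like (illustrative only, assuming the packaging library; torchtnt's actual implementation may differ):

import torch
from packaging.version import Version

def is_torch_version_geq(version: str) -> bool:
    # Drop any local build suffix (e.g. "2.1.0a0+git1234") so the remaining
    # string parses as a PEP 440 version before comparing.
    return Version(torch.__version__.split("+")[0]) >= Version(version)

# Evaluated once while the class body executes, as in the diff above.
torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")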

@@ -143,7 +145,9 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
         """

         tc = unittest.TestCase()
-        with patch("torchtnt.utils.version.is_torch_version_geq", return_value=False):
+        with patch(
+            "torchtnt.utils.prepare_module.is_torch_version_geq", return_value=False
+        ):
             with tc.assertRaisesRegex(
                 RuntimeError,
                 "Torch version >= 2.1.0 required",
@@ -155,14 +159,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
                 torch_compile_params=TorchCompileParams(backend="inductor"),
             )

-        # no error should be thrown on latest pytorch
-        prepare_module(
-            module=torch.nn.Linear(2, 2),
-            device=init_from_env(),
-            strategy=DDPStrategy(static_graph=True),
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
     def test_prepare_module_compile_invalid_backend(self) -> None:
         """
         verify error is thrown on invalid backend
@@ -188,6 +184,10 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
                 torch_compile_params=TorchCompileParams(),
             )

+    @unittest.skipUnless(
+        torch_version_geq_2_1_0,
+        reason="Must be on torch 2.1.0+ to run test",
+    )
     def test_prepare_module_compile_module_state_dict(self) -> None:
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)
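
The final hunk consumes the class-level flag added in the first hunk: the condition is evaluated while the class body runs, so on torch older than 2.1.0 the test is reported as skipped rather than failed. A standalone sketch of the same pattern (hypothetical test class; the import is the one the first hunk adds):

import unittest

import torch
from torchtnt.utils.version import is_torch_version_geq

class VersionGatedTest(unittest.TestCase):
    # Computed once while the class body executes, so the name is visible
    # to the decorator below.
    torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")

    @unittest.skipUnless(
        torch_version_geq_2_1_0,
        reason="Must be on torch 2.1.0+ to run test",
    )
    def test_compile_path(self) -> None:
        # torch.compile is a torch 2.x entry point, hence the version gate.
        compiled = torch.compile(torch.nn.Linear(2, 2))
        self.assertIsNotNone(compiled)

if __name__ == "__main__":
    unittest.main()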