diff --git a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
index a1193f1c6..3725c3eec 100644
--- a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
+++ b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py
@@ -2,6 +2,7 @@
 import logging
 import itertools
 import os
+from datetime import datetime
 from packaging.version import parse, Version
 import re
 from typing import Optional, Tuple, List
@@ -40,11 +41,19 @@ weight_coeff = 1024


 # Guard for https://github.com/pytorch/pytorch/pull/82450
+def get_nvidia_pytorch_version():
+    ver = os.getenv("NVIDIA_PYTORCH_VERSION", "22.08")
+    if "master" in ver:
+        ver = datetime.today().strftime("%y.%m")
+    elif "update_for_" in ver:
+        ver = ver.replace("update_for_", "")
+    return ver
+
 CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = False
 ngc_container_2209, pytorch_113 = Version("22.09"), Version("1.13")
-if parse(os.getenv("NVIDIA_PYTORCH_VERSION", "22.08")) >= ngc_container_2209:
+if parse(torch.__version__) >= pytorch_113:
     CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = True
-elif parse(torch.__version__) >= pytorch_113:
+elif parse(get_nvidia_pytorch_version()) >= ngc_container_2209:
     CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = True
 else:
     CAN_SKIP_SYNC_AFTER_BATCH_ISEND_IRECV = False
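
The new helper normalizes NVIDIA_PYTORCH_VERSION before it is compared against the 22.09 container cutoff, so non-numeric tags no longer make packaging.version.parse fail. The sketch below is only an illustration of that normalization under assumed environment values ("update_for_22.09", "master", plain "22.08" are hypothetical examples, not values taken from any CI config); it reuses the helper verbatim so it can be run standalone.

# Minimal standalone sketch of the normalization the patch adds.
# The sample env values below are assumptions for illustration.
import os
from datetime import datetime
from packaging.version import parse, Version


def get_nvidia_pytorch_version():
    ver = os.getenv("NVIDIA_PYTORCH_VERSION", "22.08")
    if "master" in ver:
        # Nightly/master containers carry no numeric tag; fall back to the
        # current year.month so the version comparison still parses.
        ver = datetime.today().strftime("%y.%m")
    elif "update_for_" in ver:
        # e.g. "update_for_22.09" -> "22.09" (hypothetical tag format)
        ver = ver.replace("update_for_", "")
    return ver


os.environ["NVIDIA_PYTORCH_VERSION"] = "update_for_22.09"  # hypothetical value
assert parse(get_nvidia_pytorch_version()) >= Version("22.09")

os.environ["NVIDIA_PYTORCH_VERSION"] = "22.08"
assert parse(get_nvidia_pytorch_version()) < Version("22.09")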