diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 368920f3769c6f..7cc3bef70fdc89 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -36,6 +36,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -55,6 +56,9 @@ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b9625dd9213989..dd38abfcd01cb3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -32,7 +32,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -52,6 +52,9 @@ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index d0bc55fe8383ff..1a464b62a66513 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -30,6 +30,7 @@ is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1") is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") +is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13") is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11") is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")