Skip to content

Commit

Permalink
check decoder_inputs_embeds is None before shifting labels (huggingfa…
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker authored Oct 18, 2022
1 parent d356b89 commit 3e07196
Show file tree
Hide file tree
Showing 21 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def call(
labels,
)
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ def call(
labels,
)
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,7 @@ def call(
labels,
)
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/led/modeling_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -2428,7 +2428,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/led/modeling_tf_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -2445,7 +2445,7 @@ def call(

if labels is not None:
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/marian/modeling_tf_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ def call(
labels,
)
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

outputs = self.model(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/mbart/modeling_tf_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,7 @@ def call(
labels,
)
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

outputs = self.model(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/pegasus/modeling_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,7 +1393,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/pegasus/modeling_tf_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,7 @@ def call(
labels,
)
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/pegasus_x/modeling_pegasus_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/plbart/modeling_plbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

outputs = self.model(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,7 @@ def call(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/modeling_tf_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,7 @@ def call(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if labels is not None:
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2938,7 +2938,7 @@ def call(

if labels is not None:
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2833,7 +2833,7 @@ def forward(
if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False
if decoder_input_ids is None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

outputs = self.model(
Expand Down

0 comments on commit 3e07196

Please sign in to comment.