
Commit

Create bart.py
Daniel Khashabi authored Oct 12, 2020
1 parent 64b4ee9 commit b29dab6
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions bart_example_solver/bart.py
@@ -0,0 +1,36 @@

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import T5ForConditionalGeneration, BartForConditionalGeneration

class MyBart(BartForConditionalGeneration):
    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
                decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
                use_cache=False, is_training=False):

        if is_training:
            # Teacher forcing: shift the target ids one position to the right and
            # prepend the decoder start token, so position t predicts target token t.
            decoder_start_token_id = self.config.decoder_start_token_id
            _decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)
            _decoder_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone()
            _decoder_input_ids[..., 0] = decoder_start_token_id
        else:
            _decoder_input_ids = decoder_input_ids.clone()

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=_decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_cached_states=decoder_cached_states,
            use_cache=use_cache,
        )
        # Project decoder hidden states onto the vocabulary with the shared embedding matrix.
        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
        if is_training:
            # Per-token cross-entropy with no reduction, masked by the decoder attention
            # mask so padding positions do not contribute to the loss.
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            losses = loss_fct(lm_logits.view(-1, self.config.vocab_size),
                              decoder_input_ids.view(-1))
            loss = torch.sum(losses * decoder_attention_mask.float().view(-1))
            return loss
        return (lm_logits, ) + outputs[1:]
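
A minimal usage sketch, not part of this commit: the checkpoint name, the toy question/answer pair, and the batch construction below are assumptions for illustration, and the sketch presumes a transformers release compatible with this class (the forward call above passes decoder_cached_states, which newer releases renamed).

    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")  # assumed checkpoint
    model = MyBart.from_pretrained("facebook/bart-large")

    # Hypothetical input/output pair for illustration only.
    source = "which is heavier, a ton of bricks or a ton of feathers?"
    target = "they weigh the same"

    inputs = tokenizer([source], return_tensors="pt")
    targets = tokenizer([target], return_tensors="pt")

    # Training mode: forward shifts the decoder inputs internally and returns the
    # summed, mask-weighted cross-entropy loss.
    loss = model(input_ids=inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 decoder_input_ids=targets["input_ids"],
                 decoder_attention_mask=targets["attention_mask"],
                 is_training=True)
    loss.backward()

    # Inference goes through the generate() inherited from BartForConditionalGeneration,
    # which calls forward with is_training left at False.
    predicted_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"])
    print(tokenizer.decode(predicted_ids[0], skip_special_tokens=True))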
