From b29dab67ada2e00bc755546a9d5a8aa55bc95da7 Mon Sep 17 00:00:00 2001
From: Daniel Khashabi
Date: Mon, 12 Oct 2020 12:56:51 -0700
Subject: [PATCH] Create bart.py

---
 bart_example_solver/bart.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)
 create mode 100644 bart_example_solver/bart.py

diff --git a/bart_example_solver/bart.py b/bart_example_solver/bart.py
new file mode 100644
index 0000000..9be6079
--- /dev/null
+++ b/bart_example_solver/bart.py
@@ -0,0 +1,36 @@
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from transformers import T5ForConditionalGeneration, BartForConditionalGeneration
+
+class MyBart(BartForConditionalGeneration):
+    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
+                decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
+                use_cache=False, is_training=False):
+
+        if is_training:
+            decoder_start_token_id = self.config.decoder_start_token_id
+            _decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)
+            _decoder_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone()
+            _decoder_input_ids[..., 0] = decoder_start_token_id
+        else:
+            _decoder_input_ids = decoder_input_ids.clone()
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            encoder_outputs=encoder_outputs,
+            decoder_input_ids=_decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            decoder_cached_states=decoder_cached_states,
+            use_cache=use_cache,
+        )
+        lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
+        if is_training:
+            loss_fct = nn.CrossEntropyLoss(reduce=False)
+            losses = loss_fct(lm_logits.view(-1, self.config.vocab_size),
+                              decoder_input_ids.view(-1))
+            loss = torch.sum(losses * decoder_attention_mask.float().view(-1))
+            return loss
+        return (lm_logits, ) + outputs[1:]
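
For context (not part of the patch), a minimal sketch of how the MyBart class added above might be driven for one training step. It assumes a facebook/bart-large checkpoint and a transformers release from the same era as this commit, whose BART forward still accepts decoder_cached_states; the question/answer strings and import path are illustrative only.

    import torch
    from transformers import BartTokenizer

    from bart_example_solver.bart import MyBart  # the class introduced by this patch

    # Illustrative sketch: in training mode, MyBart.forward right-shifts
    # decoder_input_ids internally (prepending decoder_start_token_id) and
    # returns a summed cross-entropy loss masked by decoder_attention_mask.
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = MyBart.from_pretrained("facebook/bart-large")

    enc = tokenizer("What is the capital of France?", return_tensors="pt")
    dec = tokenizer("Paris", return_tensors="pt")

    loss = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        decoder_input_ids=dec["input_ids"],
        decoder_attention_mask=dec["attention_mask"],
        is_training=True,
    )
    loss.backward()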