# model.py
import math
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss, SmoothL1Loss
from torch import nn
from transformers import (
    AutoConfig,
    BertConfig,
)
# from transformers.modeling_bert import BertOnlyMLMHead
# from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers.models.roberta.modeling_roberta import RobertaForMaskedLM, RobertaLMHead
from bert_model import BertForMaskedLM, BertOnlyMLMHead

BertLayerNorm = torch.nn.LayerNorm


# The gelu implementation is copied from huggingface transformers:
# https://github.com/huggingface/transformers/blob/c6acd246ec90857b70f449dcbcb1543f150821fc/src/transformers/activations.py
def _gelu_python(x):
    """Original implementation of the gelu activation function in the Google BERT repo when initially created.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


# Use the native implementation when available; fall back to the Python version on old PyTorch.
if torch.__version__ < "1.4.0":
    gelu = _gelu_python
else:
    gelu = F.gelu


class SimpleBertForMaskedLM(BertForMaskedLM):
    def __init__(self, config, args=None):
        super().__init__(config, args=args)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        lm_labels=None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)
        # Masked-LM cross-entropy; masked_lm_labels is required, with ignored
        # positions set to CrossEntropyLoss's default ignore_index (-100).
        loss_fct = CrossEntropyLoss()
        token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
        return {'loss': token_loss, 'lm_loss': token_loss}


class SimpleRobertaForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config, args=None):
        super().__init__(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        lm_labels=None,
        **kwargs
    ):
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)
        loss_fct = CrossEntropyLoss()
        token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
        return {'loss': token_loss, 'lm_loss': token_loss}
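

# --- Usage sketch (added for illustration, not part of the original training code) ---
# A minimal smoke test of SimpleRobertaForMaskedLM under assumed settings: a tiny,
# randomly initialised RobertaConfig, random token ids, and masked_lm_labels with
# non-masked positions set to -100 so CrossEntropyLoss ignores them. The config
# sizes, batch shape, and ~15% masking rate are illustrative choices, not values
# taken from this repository.
if __name__ == "__main__":
    from transformers import RobertaConfig

    config = RobertaConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
    )
    model = SimpleRobertaForMaskedLM(config)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)

    # Supervise roughly 15% of positions; everything else is ignored via -100.
    masked_lm_labels = torch.full_like(input_ids, -100)
    mask_positions = torch.rand(batch_size, seq_len) < 0.15
    mask_positions[:, 0] = True  # guarantee at least one supervised position
    masked_lm_labels[mask_positions] = input_ids[mask_positions]

    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        masked_lm_labels=masked_lm_labels,
    )
    print(out['loss'].item(), out['lm_loss'].item())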