-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathmodel.py
58 lines (52 loc) · 2.45 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import keras
from keras_self_attention import Attention
from keras_contrib.layers import CRF
def build_model(token_num,
                tag_num,
                embedding_dim=100,
                embedding_weights=None,
                rnn_units=100,
                return_attention=False,
                lr=1e-3,
                recurrent_dropout=0.4,
                attention_width=9,
                attention_activation='sigmoid',
                attention_regularizer_weight=1e-4):
    """Build and compile an Embedding -> Bi-LSTM -> Attention -> CRF tagger.

    :param token_num: Number of tokens in the word dictionary. The embedding
        layer uses ``mask_zero=True``, so input index 0 is masked (reserved
        for padding).
    :param tag_num: Number of tags (output units of the CRF layer).
    :param embedding_dim: The output dimension of the embedding layer.
    :param embedding_weights: Initial weights for the embedding layer, either
        a single weight array or a list of weight arrays. When given, the
        embedding layer is frozen (``trainable=False``); otherwise it is
        trained from scratch.
    :param rnn_units: The number of RNN units in a single direction.
    :param return_attention: Whether to also output the attention matrix.
        When True the model has two outputs, named ``'CRF'`` and
        ``'Attention'``, and the attention output gets its own loss from
        ``Attention.loss``.
    :param lr: Learning rate of the Adam optimizer.
    :param recurrent_dropout: ``recurrent_dropout`` of the LSTM cells.
    :param attention_width: Value passed as ``attention_width`` to the
        ``Attention`` layer.
    :param attention_activation: Value passed as ``attention_activation`` to
        the ``Attention`` layer.
    :param attention_regularizer_weight: Argument to ``Attention.loss`` for the
        attention output. Only used when ``return_attention`` is True.
    :return model: The built and compiled model.
    """
    # Keras' ``weights=`` argument expects a list of arrays, so a bare weight
    # matrix is wrapped.
    if embedding_weights is not None and not isinstance(embedding_weights, list):
        embedding_weights = [embedding_weights]

    # Shape (None,) accepts token-index sequences of any length.
    input_layer = keras.layers.Input(shape=(None,))
    embd_layer = keras.layers.Embedding(input_dim=token_num,
                                        output_dim=embedding_dim,
                                        mask_zero=True,
                                        weights=embedding_weights,
                                        # Pre-trained embeddings stay frozen.
                                        trainable=embedding_weights is None,
                                        name='Embedding')(input_layer)
    lstm_layer = keras.layers.Bidirectional(
        keras.layers.LSTM(units=rnn_units,
                          recurrent_dropout=recurrent_dropout,
                          return_sequences=True),
        name='Bi-LSTM',
    )(embd_layer)
    attention_layer = Attention(attention_activation=attention_activation,
                                attention_width=attention_width,
                                return_attention=return_attention,
                                name='Attention')(lstm_layer)
    if return_attention:
        # With return_attention=True the layer yields (output, attention matrix).
        attention_layer, attention = attention_layer

    # NOTE(review): sparse_target=True presumably means integer tag ids rather
    # than one-hot targets -- confirm against the training data pipeline.
    crf = CRF(units=tag_num, sparse_target=True, name='CRF')
    outputs = [crf(attention_layer)]
    # Loss and metric dicts are keyed by the output layer names set above.
    loss = {'CRF': crf.loss_function}
    if return_attention:
        outputs.append(attention)
        loss['Attention'] = Attention.loss(attention_regularizer_weight)

    model = keras.models.Model(inputs=input_layer, outputs=outputs)
    model.compile(
        # NOTE(review): newer Keras renames ``lr`` to ``learning_rate``; ``lr``
        # is kept for the Keras release keras_contrib works with -- confirm
        # before upgrading.
        optimizer=keras.optimizers.Adam(lr=lr),
        loss=loss,
        metrics={'CRF': crf.accuracy},
    )
    return model