forked from PaddlePaddle/ERNIE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfinetune_classifier.py
268 lines (245 loc) · 9.83 KB
/
finetune_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import time
import logging
import json
from random import random
from functools import reduce, partial
from visualdl import LogWriter
import numpy as np
import logging
import argparse
from pathlib import Path
import paddle as P
from propeller import log
import propeller.paddle as propeller
log.setLevel(logging.DEBUG)
logging.getLogger().setLevel(logging.DEBUG)
#from model.bert import BertConfig, BertModelLayer
from ernie.modeling_ernie import ErnieModel, ErnieModelForSequenceClassification
from ernie.tokenizing_ernie import ErnieTokenizer, ErnieTinyTokenizer
#from ernie.optimization import AdamW, LinearDecay
from demo.utils import create_if_not_exists, get_warmup_and_linear_decay
# Command-line interface of the ERNIE classification fine-tuning script.
parser = argparse.ArgumentParser('classify model with ERNIE')

# One (flag, add_argument keyword options) pair per option. The table is
# registered top to bottom, so options keep their declaration order.
_CLI_OPTIONS = (
    ('--from_pretrained', dict(
        type=Path,
        required=True,
        help='pretrained model directory or tag')),
    ('--max_seqlen', dict(
        type=int,
        default=128,
        help='max sentence length, should not greater than 512')),
    ('--bsz', dict(
        type=int,
        default=128,
        help='global batch size for each optimizer step')),
    ('--micro_bsz', dict(
        type=int,
        default=32,
        help='batch size for each device. if `--bsz` > `--micro_bsz` * num_device, will do grad accumulate')),
    ('--epoch', dict(type=int, default=3, help='epoch')),
    ('--data_dir', dict(
        type=str,
        required=True,
        help='data directory includes train / develop data')),
    ('--use_lr_decay', dict(
        action='store_true',
        help='if set, learning rate will decay to zero at `max_steps`')),
    ('--warmup_proportion', dict(
        type=float,
        default=0.1,
        help='if use_lr_decay is set, '
        'learning rate will raise to `lr` at `warmup_proportion` * `max_steps` and decay to 0. at `max_steps`')),
    ('--lr', dict(type=float, default=5e-5, help='learning rate')),
    ('--inference_model_dir', dict(
        type=Path,
        default=None,
        help='inference model output directory')),
    ('--save_dir', dict(
        type=Path,
        required=True,
        help='model output directory')),
    ('--max_steps', dict(
        type=int,
        default=None,
        help='max_train_steps, set this to EPOCH * NUM_SAMPLES / BATCH_SIZE')),
    ('--wd', dict(
        type=float,
        default=0.01,
        help='weight decay, aka L2 regularizer')),
    ('--init_checkpoint', dict(
        type=str,
        default=None,
        help='checkpoint to warm start from')),
    ('--use_amp', dict(
        action='store_true',
        help='only activate AMP(auto mixed precision accelatoin) on TensorCore compatible devices')),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Validate batch sizes through the parser instead of `assert`: assertions are
# stripped under `python -O`, and a zero micro batch would otherwise surface
# as a bare ZeroDivisionError in the divisibility check below.
if args.micro_bsz < 1:
    parser.error('--micro_bsz must be a positive integer, got %d' %
                 args.micro_bsz)

# Gradient accumulation: when the global batch is larger than the micro batch,
# `acc_step` micro batches are accumulated per optimizer step, and `args.bsz`
# is overwritten with the micro batch size so the datasets built later batch
# at that size.
if args.bsz > args.micro_bsz:
    if args.bsz % args.micro_bsz != 0:
        parser.error(
            'cannot perform gradient accumulate with bsz:%d micro_bsz:%d' %
            (args.bsz, args.micro_bsz))
    acc_step = args.bsz // args.micro_bsz
    log.info(
        'performing gradient accumulate: global_bsz:%d, micro_bsz:%d, accumulate_steps:%d'
        % (args.bsz, args.micro_bsz, acc_step))
    args.bsz = args.micro_bsz
else:
    acc_step = 1
# Tokenizer loaded from the same location as the pretrained model.
# (ErnieTinyTokenizer.from_pretrained is the drop-in alternative that the
# original author left commented out here.)
tokenizer = ErnieTokenizer.from_pretrained(args.from_pretrained)

# Label strings mapped to class ids; both spellings of the contradiction
# label share id 0.
_LABEL_VOCAB = {
    b"contradictory": 0,
    b"contradiction": 0,
    b"entailment": 1,
    b"neutral": 2,
}


def _text_column(column_name):
    """Build a TextColumn named *column_name* backed by the shared tokenizer."""
    return propeller.data.TextColumn(
        column_name,
        unk_id=tokenizer.unk_id,
        vocab_dict=tokenizer.vocab,
        tokenizer=tokenizer.tokenize)


# Record schema: two text segments followed by the label.
feature_column = propeller.data.FeatureColumns([
    _text_column('seg_a'),
    _text_column('seg_b'),
    propeller.data.LabelColumn('label', vocab_dict=_LABEL_VOCAB),
])
def map_fn(seg_a, seg_b, label):
    """Turn one (seg_a, seg_b, label) record into ERNIE model inputs.

    The two segments are passed through `tokenizer.truncate` with
    `seqlen=args.max_seqlen` and then combined by
    `tokenizer.build_for_ernie`, which yields the sentence and segment
    sequences. The label is passed through unchanged.

    Returns:
        A `(sentence, segments, label)` tuple.
    """
    short_a, short_b = tokenizer.truncate(
        seg_a, seg_b, seqlen=args.max_seqlen)
    sentence, segments = tokenizer.build_for_ernie(short_a, short_b)
    return sentence, segments, label
def _build_split(split_name, shuffle):
    """Build the batched dataset for one split under `args.data_dir`.

    Records are read from `<data_dir>/<split_name>`, mapped through `map_fn`
    and batched at `args.bsz`; `(0, 0, 0)` is handed to `padded_batch` as its
    second argument (presumably the per-field pad values -- confirm against
    propeller).
    """
    raw_ds = feature_column.build_dataset(
        split_name,
        data_dir=os.path.join(args.data_dir, split_name),
        shuffle=shuffle,
        repeat=False,
        use_gz=False)
    mapped_ds = raw_ds.map(map_fn)
    return mapped_ds.padded_batch(args.bsz, (0, 0, 0))


# Only the training split is shuffled.
train_ds = _build_split('train', shuffle=True)
dev_ds = _build_split('dev', shuffle=False)
# Device handle for GPU 0.
place = P.CUDAPlace(0)

# Three-way sequence classifier initialised from the pretrained weights.
model = ErnieModelForSequenceClassification.from_pretrained(
    args.from_pretrained, num_labels=3, name='')

# Optional warm start: overwrite the model state with a saved checkpoint.
if args.init_checkpoint is not None:
    log.info('loading checkpoint from %s' % args.init_checkpoint)
    checkpoint_state = P.load(str(args.init_checkpoint))
    model.set_state_dict(checkpoint_state)
# Clip gradients by global norm (threshold 1.0); marked experimental upstream.
g_clip = P.nn.ClipGradByGlobalNorm(1.0)

# Parameters whose names match this pattern are excluded from weight decay.
param_name_to_exclue_from_weight_decay = re.compile(
    r'.*layer_norm_scale|.*layer_norm_bias|.*b_0')


def _apply_weight_decay(param_name):
    """Return True when *param_name* should receive weight decay."""
    return not param_name_to_exclue_from_weight_decay.match(param_name)


if args.use_lr_decay:
    # The schedule is defined in terms of `max_steps`; without it the warmup
    # computation below would fail with an opaque `float * None` TypeError.
    if args.max_steps is None:
        raise ValueError(
            '--max_steps must be set when --use_lr_decay is enabled')
    lr_scheduler = P.optimizer.lr.LambdaDecay(
        args.lr,
        get_warmup_and_linear_decay(
            args.max_steps, int(args.warmup_proportion * args.max_steps)))
else:
    lr_scheduler = None

# AdamW takes either the scheduler or the constant learning rate; every other
# setting is identical in both modes.
opt = P.optimizer.AdamW(
    lr_scheduler if lr_scheduler is not None else args.lr,
    parameters=model.parameters(),
    weight_decay=args.wd,
    apply_decay_param_fun=_apply_weight_decay,
    grad_clip=g_clip)
# Loss scaler for mixed precision; constructed with enable=args.use_amp.
scaler = P.amp.GradScaler(enable=args.use_amp)
# `step` counts optimizer updates, `inter_step` counts micro batches.
step, inter_step = 0, 0
with LogWriter(
        logdir=str(create_if_not_exists(args.save_dir / 'vdl'))) as log_writer:
    with P.amp.auto_cast(enable=args.use_amp):
        for epoch in range(args.epoch):
            for ids, sids, label in P.io.DataLoader(
                    train_ds, places=P.CUDAPlace(0), batch_size=None):
                inter_step += 1
                loss, _ = model(ids, sids, labels=label)
                # Divide by the accumulation count so the summed gradients of
                # `acc_step` micro batches correspond to one averaged batch.
                loss /= acc_step
                loss = scaler.scale(loss)
                loss.backward()
                # Keep accumulating until `acc_step` micro batches were seen.
                if inter_step % acc_step != 0:
                    continue
                step += 1
                scaler.minimize(opt, loss)
                model.clear_gradients()
                if lr_scheduler is not None:
                    lr_scheduler.step()

                # Training metrics every 10 optimizer steps.
                if step % 10 == 0:
                    if lr_scheduler is not None:
                        _lr = lr_scheduler.get_lr()
                    else:
                        _lr = args.lr
                    if args.use_amp:
                        # Undo the loss scaling before reporting the value.
                        _l = (loss / scaler._scale).numpy()
                        msg = '[step-%d] train loss %.5f lr %.3e scaling %.3e' % (
                            step, _l, _lr, scaler._scale.numpy())
                    else:
                        _l = loss.numpy()
                        msg = '[step-%d] train loss %.5f lr %.3e' % (step, _l,
                                                                     _lr)
                    log.debug(msg)
                    log_writer.add_scalar('loss', _l, step=step)
                    log_writer.add_scalar('lr', _lr, step=step)

                # Dev-set accuracy and a checkpoint every 100 optimizer steps.
                if step % 100 == 0:
                    hits = []
                    with P.no_grad():
                        model.eval()
                        # Separate names so evaluation never rebinds the
                        # training loop's batch tensors or its loss.
                        for dev_ids, dev_sids, dev_label in P.io.DataLoader(
                                dev_ds, places=P.CUDAPlace(0),
                                batch_size=None):
                            _, logits = model(
                                dev_ids, dev_sids, labels=dev_label)
                            correct = (logits.argmax(-1) == dev_label)
                            hits.append(correct.numpy())
                        model.train()
                    acc = np.concatenate(hits).mean()
                    log_writer.add_scalar('eval/acc', acc, step=step)
                    log.debug('acc %.5f' % acc)
                    if args.save_dir is not None:
                        P.save(model.state_dict(),
                               str(args.save_dir / 'ckpt.bin'))

# Final checkpoint once all epochs are done.
if args.save_dir is not None:
    P.save(model.state_dict(), str(args.save_dir / 'ckpt.bin'))
# Optional export of a traced inference model.
if args.inference_model_dir is not None:

    class InferenceModel(ErnieModelForSequenceClassification):
        """Classifier variant whose forward pass returns only the logits."""

        def forward(self, ids, sids):
            # The parent forward returns a pair; only the second element
            # (the logits) is exposed.
            _, logits = super().forward(ids, sids)
            return logits

    # Re-class the trained model in place so the trace uses the logits-only
    # forward defined above.
    model.__class__ = InferenceModel
    log.debug('saving inference model')

    # Dummy int64 inputs of shape [2, 2] that drive the trace.
    dummy_ids = P.zeros([2, 2], dtype='int64')
    dummy_sids = P.zeros([2, 2], dtype='int64')
    _, traced_program = P.jit.TracedLayer.trace(
        model, inputs=[dummy_ids, dummy_sids])
    traced_program.save_inference_model(str(args.inference_model_dir))

    # NOTE: the original author left a disabled alternative here that
    # decorated `forward` with `P.jit.to_static` and exported via `P.jit.save`.
    log.debug('done')