trainer.py
import argparse
import os

import torch
from torch.optim import Adam

from src.load_data import data_builder
from src.models import *

def save_checkpoint(model, model_name, cur_epoch, is_best=False):
    param_grad_dic = {
        k: v.requires_grad for (k, v) in model.named_parameters()
    }
    state_dict = model.state_dict()
    for k in list(state_dict.keys()):
        if k in param_grad_dic.keys() and not param_grad_dic[k]:
            # delete parameters that do not require gradient
            del state_dict[k]
    save_obj = {"model": state_dict, "epoch": cur_epoch}
    # drop earlier checkpoints for this model before writing the new one
    os.system("rm trained_models/%s*" % model_name)
    save_to = "trained_models/%s_%s.pth" % (model_name, "best" if is_best else str(cur_epoch))
    print("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
    torch.save(save_obj, save_to)

def load_from_checkpoint(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cuda")
    try:
        model.load_state_dict(checkpoint["model"])
    except RuntimeError:
        # checkpoints only store trainable parameters, so fall back to a
        # non-strict load that leaves the frozen weights untouched
        model.load_state_dict(checkpoint["model"], strict=False)
    return model

def train(model, model_name, data_loader, epochs=100, save_iters=10, starting_epoch=1):
    model.train()
    optim = Adam(model.parameters(), lr=0.0001)
    best_loss = float("inf")
    for epoch in range(starting_epoch, epochs + 1):
        print('-------- Epoch:', epoch)
        mean_loss = 0.0
        if epoch > (epochs - 20):
            # checkpoint more frequently over the final 20 epochs
            save_iters = 2
        for sample in data_loader:
            loss = model(sample)
            mean_loss += loss.item()  # detach before accumulating so the graph is freed
            optim.zero_grad()
            loss.backward()
            optim.step()
        mean_loss /= len(data_loader)
        print(mean_loss)
        if epoch % save_iters == 0 and mean_loss < best_loss:
            best_loss = mean_loss
            save_checkpoint(model, model_name, epoch)

def test(model, data_loader, model_name):
    model.eval()
    with open("results/%s.txt" % model_name, "w") as f:
        for sample in data_loader:
            output_text = model.generate(sample)
            for predicted, target in zip(output_text, sample["text_output"]):
                f.write("The predicted Conversation: " + predicted + "\n")
                f.write("The target Conversation: " + target + "\n")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--seed", default=3, type=int)
    parser.add_argument("--model_name", "-m", help="Name of the model to train.",
                        choices=["MllmBrainToTextV0", "MllmBrainToTextV1", "MllmBrainToTextV2"])
    parser.add_argument("--test", action="store_true", help="test the model")
    parser.add_argument("--retrain", action="store_true", help="retrain from an existing checkpoint")
    parser.add_argument("--starting_epoch", default=1, type=int)
    parser.add_argument("--save_epochs", default=5, type=int)
    parser.add_argument("--epochs", default=300, type=int)
    parser.add_argument("--saved_checkpoint", type=str)
    args = parser.parse_args()

    models_dict = {
        "MllmBrainToTextV0": MllmBrainToTextV0,
        "MllmBrainToTextV1": MllmBrainToText,
        "MllmBrainToTextV2": MllmBrainToTextV2,
    }
    torch.manual_seed(args.seed)
    data_loader = data_builder(args.batch_size)
    llm = models_dict[args.model_name]()
    if args.test:
        llm = load_from_checkpoint(llm, args.saved_checkpoint)
        test(llm, data_loader["test"], args.model_name)
    else:
        if args.retrain:
            llm = load_from_checkpoint(llm, args.saved_checkpoint)
        train(llm,
              args.model_name,
              data_loader["train"],
              epochs=args.epochs,
              save_iters=args.save_epochs,
              starting_epoch=args.starting_epoch)
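
# Example invocations, sketched from the flags defined above. The checkpoint
# paths are illustrative (actual filenames depend on the epochs at which
# checkpoints were saved), and the "trained_models/" and "results/"
# directories are assumed to exist:
#
#   python trainer.py -m MllmBrainToTextV2 --epochs 300 --save_epochs 5
#   python trainer.py -m MllmBrainToTextV2 --retrain \
#       --saved_checkpoint trained_models/MllmBrainToTextV2_100.pth --starting_epoch 101
#   python trainer.py -m MllmBrainToTextV2 --test \
#       --saved_checkpoint trained_models/MllmBrainToTextV2_300.pth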