-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModelLoader.py
155 lines (137 loc) · 6.04 KB
/
ModelLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import torch
from torch.utils.data import DataLoader, random_split
class ModelLoader:
    """Base class wiring datasets, data loaders and checkpointing for a model.

    Subclasses are expected to assign ``self.model`` (and optionally
    ``self.optimizer`` / ``self.scheduler``) and to implement
    ``_train_epoch``, ``_test_epoch`` and ``test``.
    """

    def __init__(self, train_dataset, test_dataset, batch_size, model_path: str, if_early_stop=False,
                 debug_mode=False, test_ratio=0.1):
        """
        Initialize the model loader.

        Args:
            train_dataset: Dataset used for training, or None.
            test_dataset: Dataset used for evaluation, or None. If only
                ``train_dataset`` is given, a random ``test_ratio`` fraction of
                it is split off as the test set.
            batch_size: Number of samples per batch.
            model_path: Path where model weights are saved to / loaded from.
            if_early_stop: Whether to stop training once the test loss stops
                improving.
            debug_mode: Enables CUDA_LAUNCH_BLOCKING and autograd anomaly
                detection to make errors easier to locate.
            test_ratio: Fraction of ``train_dataset`` held out as the test set
                when no ``test_dataset`` is given.

        With neither dataset the loader runs in predict mode. With only a test
        dataset, just the test iterator is built and training is skipped.
        """
        self.predict_mode = False
        self.batch_size = batch_size
        # Always define the data attributes so every mode has a consistent shape.
        self.train_dataset = None
        self.test_dataset = None
        self.train_iterator = None
        self.test_iterator = None
        if train_dataset is None and test_dataset is None:
            print('Model will run in predict mode.')
            self.predict_mode = True
        elif train_dataset is None:
            # Evaluation-only setup (previously fell through with nothing set).
            print('No training dataset given, only the test dataset will be used.')
            self.test_dataset = test_dataset
            self.test_iterator = DataLoader(self.test_dataset, batch_size=self.batch_size)
        else:
            if test_dataset is None:
                print('Generate test dataset randomly.')
                # No test set provided: hold out a random fraction of the training set.
                train_size = int(len(train_dataset) * (1 - test_ratio))
                test_size = len(train_dataset) - train_size
                train_dataset, test_dataset = random_split(train_dataset, [train_size, test_size])
            self.train_dataset = train_dataset
            self.test_dataset = test_dataset
            self.train_iterator = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
            self.test_iterator = DataLoader(self.test_dataset, batch_size=self.batch_size)
        self.if_early_stop = if_early_stop
        self.debug_mode = debug_mode
        if debug_mode:
            # NOTE(review): CUDA_LAUNCH_BLOCKING presumably only takes effect if
            # CUDA has not been initialized yet in this process — confirm.
            os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
            torch.autograd.set_detect_anomaly(True)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_path = model_path
        self.model = None
        self.lr = None
        self.optimizer = None
        self.scheduler = None
        self.best_loss = float('inf')

    def _train_epoch(self):
        """
        Train the model for one epoch and return the training loss.
        Subclasses must implement this for their concrete model.
        """
        raise NotImplementedError('Subclasses should implement this method.')

    def _test_epoch(self):
        """
        Evaluate the model for one epoch and return the test loss.
        Subclasses must implement this for their concrete model.
        """
        raise NotImplementedError('Subclasses should implement this method.')

    def _step_scheduler(self, test_loss):
        """Advance the LR scheduler, passing the test loss to metric-driven schedulers."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau.step() requires the monitored metric.
            self.scheduler.step(test_loss)
        else:
            self.scheduler.step()

    def train(self, epochs=50, test_interval=10):
        """
        Train the model, periodically running a full evaluation on the test set.

        Args:
            epochs: Total number of training epochs.
            test_interval: Call ``self.test()`` every this many epochs; a falsy
                value (0 / None) disables the periodic evaluation.

        With early stopping enabled, training stops after
        ``max(epochs // 5, 10)`` consecutive epochs without a new best test loss.
        """
        if self.predict_mode:
            print('No data given, model is running in predict mode.')
            return
        if self.train_iterator is None:
            print('No training data given, skip training.')
            return
        print('Start training...')
        patience = max(epochs // 5, 10)
        best_test_loss = float('inf')
        stale_epochs = 0
        for epoch in range(1, epochs + 1):
            train_loss = self._train_epoch()
            test_loss = self._test_epoch()
            print(f'Epoch {epoch}/{epochs} - Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f}')
            self._step_scheduler(test_loss)
            self.save_model(test_loss)
            if test_interval and epoch % test_interval == 0:
                self.test()
            if not self.if_early_stop:
                continue
            # Early stopping to prevent overfitting.
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                stale_epochs = 0
            else:
                stale_epochs += 1
            if stale_epochs >= patience:
                print('Training interrupted to avoid overfitting.')
                break

    def test(self):
        """Run a full evaluation; subclasses must implement this."""
        raise NotImplementedError('Subclasses should implement this method.')

    def _ensure_model_dir(self):
        """Create the directory containing ``model_path`` if it does not exist yet."""
        model_dir = os.path.dirname(self.model_path)
        # A bare filename has an empty dirname (the current directory): nothing to create.
        if model_dir and not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
            print(f"Created directory '{model_dir}' for saving models.")

    def save_model(self, loss=float('inf')):
        """
        Save the model weights to ``model_path``.

        If ``loss`` beats the best loss seen so far, the weights are also saved
        to ``<name>_best<ext>`` next to ``model_path``.

        Args:
            loss: Loss associated with the current weights; the default (inf)
                never updates the best checkpoint.
        """
        if self.model is None:
            # Avoid an AttributeError here masking the real error in run()'s finally.
            print('No model to save.')
            return
        self._ensure_model_dir()
        if loss < self.best_loss:
            self.best_loss = loss
            model_name, model_extension = os.path.splitext(self.model_path)
            best_model_path = f"{model_name}_best{model_extension}"
            torch.save(self.model.state_dict(), best_model_path)
            print(f'Model saved to {best_model_path}')
        torch.save(self.model.state_dict(), self.model_path)
        print(f'Model saved to {self.model_path}')

    def load_model(self):
        """
        Load model weights from ``model_path`` into ``self.model`` if the file exists.

        Creates the containing directory when missing. Load failures are
        reported and ignored so that training can start from scratch.
        """
        print('Try to load model from', self.model_path)
        self._ensure_model_dir()
        if not os.path.isfile(self.model_path):
            print("No saved model found at '{}'. Starting from scratch.".format(self.model_path))
            return
        try:
            # torch.load unpickles the file: only load checkpoints from trusted sources.
            state_dict = torch.load(self.model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        except Exception as e:
            # Best effort: a corrupt or incompatible checkpoint should not abort the run.
            print("Failed to load model. Starting from scratch. Error: ", e)
        else:
            print("Model loaded successfully from '{}'".format(self.model_path))

    def run(self, train_epochs=50):
        """
        Run training followed by a final evaluation.

        A KeyboardInterrupt stops the run gracefully. The latest weights are
        saved on exit in every case.
        """
        try:
            self.train(train_epochs)
            self.test()
        except KeyboardInterrupt:
            print('Training interrupted by the user.')
        finally:
            self.save_model()