-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtrain_paddle.py
46 lines (40 loc) · 1.68 KB
/
train_paddle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
from paddlex.utils.result_saver import try_except_decorator
from paddlex.utils.config import parse_args, get_config
from paddlex.utils.errors import raise_unsupported_api_error
from paddlex.model import _ModelBasedConfig
# Reminder of the dataset layout this script expects, printed on start-up.
# English translation of the text below:
#   "Dataset format:"
#   dataset
#   ├─images     -- directory holding all images
#   ├─label.txt  -- labels; each line is <index><space><category>, e.g. "15 地球仪" (globe)
#   ├─train.txt  -- training images; each line is <image path><space><category>,
#   │               e.g. "images/001.jpg 0"
#   └─ validation and test sets: same as above
# A plain (non-f) string is used on purpose: the text has no placeholders, and
# an f-string would treat any "{" or "}" later added to it as a format field.
_DATASET_FORMAT_HINT = """数据集格式如下:
dataset
├─images #所有图片存放路径
├─label.txt #标签路径,每一行数据格式为 <序号>+<空格>+<类别>,如15 地球仪
├─train.txt #训练图片,每一行数据格式为 <图片路径>+<空格>+<类别>,如images/001.jpg 0
└─验证集和测试集同上
"""
print(_DATASET_FORMAT_HINT)
class Engine(object):
    """Run one PaddleX task on a model built from a YAML config.

    The config is loaded once in ``__init__``; :meth:`run` then dispatches on
    ``Global.mode`` to ``check_dataset``, ``train``, ``evaluate``, ``export``
    or ``predict`` on the wrapped ``_ModelBasedConfig``.
    """

    # The values this script has always hard-coded; they remain the defaults
    # so that ``Engine()`` behaves exactly as before.
    DEFAULT_CONFIG = "PP-HGNetV2-B4.yaml"
    DEFAULT_OVERRIDES = ("Global.mode=train", "Global.dataset_dir=dataset")

    def __init__(self, config_path=DEFAULT_CONFIG, overrides=None):
        """Load the config and build the model wrapper.

        Args:
            config_path: Path of the YAML config passed to ``get_config``.
                Defaults to ``DEFAULT_CONFIG``.
            overrides: Iterable of override strings such as
                ``"Global.mode=train"``, forwarded to ``get_config``.
                ``None`` selects ``DEFAULT_OVERRIDES``
                (``Global.mode=train``, ``Global.dataset_dir=dataset``).
        """
        # NOTE(review): parse_args() is still called so whatever command-line
        # handling it performs is unchanged; this script has always discarded
        # the config/override values it returns in favour of the ones above.
        # TODO confirm whether CLI-supplied values should take precedence.
        parse_args()
        if overrides is None:
            overrides = self.DEFAULT_OVERRIDES
        # get_config receives a list, as before, and a fresh one so the
        # caller's sequence can never be mutated through it.
        config = get_config(config_path, overrides=list(overrides), show=False)
        self._mode = config.Global.mode
        # Not read anywhere in this class.
        self._output = config.Global.output
        self._model = _ModelBasedConfig(config)

    @try_except_decorator
    def run(self):
        """Execute the task selected by ``Global.mode``.

        Returns:
            The result of the model call for ``check_dataset``, ``train``,
            ``evaluate`` and ``export`` (as passed through
            ``try_except_decorator``, whose behaviour is defined in paddlex);
            ``None`` for ``predict``, whose results are printed instead.
        """
        if self._mode == "check_dataset":
            return self._model.check_dataset()
        elif self._mode == "train":
            # Return the training result, consistent with the other modes.
            return self._model.train()
        elif self._mode == "evaluate":
            return self._model.evaluate()
        elif self._mode == "export":
            return self._model.export()
        elif self._mode == "predict":
            for res in self._model.predict():
                res.print(json_format=False)
        else:
            # Any other mode is reported through paddlex's unsupported-API error.
            raise_unsupported_api_error(f"{self._mode}", self.__class__)
if __name__ == "__main__":
    # Script entry point: build the engine, then run the task its config selects.
    engine = Engine()
    engine.run()