forked from frotms/image_classification_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
73 lines (57 loc) · 2.46 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# coding=utf-8
import argparse
import textwrap
import time
import os, sys
sys.path.append(os.path.dirname(__file__))
from utils.config import process_config, check_config_dict
from utils.logger import ExampleLogger
from trainers.example_model import ExampleModel
from trainers.example_trainer import ExampleTrainer
from data_loader.dataset import get_data_loader
# Experiment configuration, loaded from configs/config.json located next to
# this script. NOTE: this runs at import time, so merely importing this module
# reads (and processes) the config file.
_config_path = os.path.join(os.path.dirname(__file__), 'configs', 'config.json')
config = process_config(_config_path)
class ImageClassificationPytorch:
    """Wire together the model, data loaders, logger and trainer for one run.

    All setup happens during construction (see ``init``); ``run`` starts
    training and ``close`` closes the logger.
    """

    def __init__(self, config):
        """Select GPUs, validate ``config`` and build every training component.

        Args:
            config: configuration object as returned by ``process_config``;
                it is indexed with ``config['gpu_id']``, so that entry must
                exist. ``gpu_id`` may be a string (``"0"``, ``"0,1"``), an
                int, or a list/tuple of ids.
        """
        gpu_id = config['gpu_id']
        # GPU selection is applied before anything else in setup runs --
        # presumably so it takes effect before CUDA gets initialized (the
        # original ordering is kept for that reason; TODO confirm that
        # check_config_dict does not need to run first).
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        # os.environ only accepts str values, so normalise int/list configs.
        os.environ["CUDA_VISIBLE_DEVICES"] = self._format_gpu_ids(gpu_id)
        check_config_dict(config)
        self.config = config
        self.init()

    @staticmethod
    def _format_gpu_ids(gpu_id):
        """Return ``gpu_id`` in the form expected by ``CUDA_VISIBLE_DEVICES``.

        A list/tuple of ids is joined with commas (``[0, 1]`` -> ``'0,1'``),
        a plain int becomes its decimal string, and anything else (normally
        already a string) is returned unchanged.
        """
        if isinstance(gpu_id, (list, tuple)):
            return ','.join(str(g) for g in gpu_id)
        # bool is a subclass of int, but "True"/"False" is never a device id.
        if isinstance(gpu_id, int) and not isinstance(gpu_id, bool):
            return str(gpu_id)
        return gpu_id

    def init(self):
        """Build the components of the run from ``self.config``.

        Sets ``self.model``, ``self.train_loader``, ``self.test_loader``,
        ``self.logger`` and ``self.trainer``.
        """
        # Create the network, then call its load() hook (what is loaded, and
        # from where, is decided inside ExampleModel).
        self.model = ExampleModel(self.config)
        self.model.load()
        # Data loaders returned as a (train, test) pair.
        self.train_loader, self.test_loader = get_data_loader(self.config)
        # Logger; hand it the full config at the start of the run.
        self.logger = ExampleLogger(self.config)
        self.logger.write_train_info_to_logger(variable_dict=self.config)
        # Create the trainer and pass all previously built components to it.
        self.trainer = ExampleTrainer(self.model, self.train_loader, self.test_loader, self.config, self.logger)

    def run(self):
        """Train the model by delegating to the trainer."""
        self.trainer.train()

    def close(self):
        """Close the logger created in ``init``."""
        self.logger.close()
def main():
    """Build the pipeline from the module-level ``config``, train, then clean up.

    The logger is closed even if training raises; the exception still
    propagates to the caller.
    """
    classifier = ImageClassificationPytorch(config)
    try:
        classifier.run()
    finally:
        # Release the logger whether or not training completed.
        classifier.close()
if __name__ == '__main__':
    # Script entry point: print a banner with the start time, run training,
    # then print the start time again alongside the current time.
    stamp_fmt = '%Y-%m-%d | %H:%M:%S'
    rule = '----------------------------------------------------------------------'
    started_at = time.strftime(stamp_fmt, time.localtime(time.time()))
    for banner_line in (rule, 'Time: ' + started_at, rule, ' Now start ...', rule):
        print(banner_line)
    main()
    for banner_line in (rule, ' All Done!', rule, 'Start time: ' + started_at):
        print(banner_line)
    # Timestamp taken right when it is printed, as in the original.
    print('Now time: ' + time.strftime(stamp_fmt, time.localtime(time.time())))
    print(rule)