-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
57 lines (47 loc) · 1.87 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch.utils.data import DataLoader

from datasets import semantickitti
from modules import salsanext
from utils import common
# Dataset registry: config["dataset"] -> dataset class. The class is called as
# cls(config, split) to build each split, and is also handed (uninstantiated)
# to the model constructor.
dataset_helper = dict(
    semantickitti=semantickitti.SemanticKitti,
)
# Model registry: config["model"] -> model class, constructed as
# cls(config, dataset_class, device).
model_helper = dict(
    salsanext=salsanext.SalsaNext,
)
class Trainer:
    """Build the dataloaders and model described by a YAML config and train it.

    Config keys read by this class: ``dataset``, ``model``, ``batch_size``,
    ``num_workers``, ``epochs`` and ``valid_every``.
    """

    def __init__(self, args, writer=None):
        """Load the config, then create dataloaders, pick a device and build the model.

        Args:
            args: Parsed CLI arguments; only ``args.config`` (path to the YAML
                config file) is read.
            writer: Object forwarded to ``model.train`` / ``model.valid`` on
                every call. NOTE(review): presumably a TensorBoard-style
                ``SummaryWriter``; defaults to ``None`` — confirm the model
                accepts ``None``.
        """
        self.config = common.read_yaml(args.config)
        # Assigned up front so _run can always forward it to the model.
        self.writer = writer
        self._create_dataloader()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._create_model()

    def _create_model(self):
        """Instantiate the model registered under ``config["model"]``.

        The model receives the config, the dataset *class* (not an instance)
        registered under ``config["dataset"]``, and the selected torch device.
        """
        model_cls = model_helper[self.config["model"]]
        dataset_cls = dataset_helper[self.config["dataset"]]
        self.model = model_cls(self.config, dataset_cls, self.device)

    def _create_dataloader(self):
        """Create ``self.train_dataloader`` (shuffled) and ``self.val_dataloader`` (not shuffled)."""
        self.train_dataloader = self._build_dataloader("train", shuffle=True)
        self.val_dataloader = self._build_dataloader("val", shuffle=False)

    def _build_dataloader(self, split, shuffle):
        """Return a DataLoader over the configured dataset for *split*.

        Args:
            split: Passed straight to the dataset constructor ("train"/"val").
            shuffle: Whether the DataLoader reshuffles every epoch.
        """
        dataset = dataset_helper[self.config["dataset"]](self.config, split)
        return DataLoader(
            dataset=dataset,
            batch_size=self.config["batch_size"],
            shuffle=shuffle,
            drop_last=False,  # keep the final partial batch
            num_workers=self.config["num_workers"],
        )

    def run(self):
        """Public entry point for training; equivalent to ``_run``."""
        return self._run()

    def _run(self):
        """Train for ``config["epochs"]`` epochs, validating every ``valid_every`` epochs.

        Validation runs after each epoch whose 1-based index is a multiple of
        ``valid_every``; ``valid_every=1`` validates after every epoch.

        Returns:
            A list with one dict per epoch, ``{"epoch", "train", "valid"}``,
            holding whatever ``model.train`` / ``model.valid`` returned
            (``"valid"`` is ``None`` for epochs without validation).

        Raises:
            ValueError: If ``config["valid_every"]`` is not positive.
        """
        valid_every = self.config["valid_every"]
        if valid_every <= 0:
            raise ValueError(f"valid_every must be a positive integer, got {valid_every!r}")

        history = []
        for epoch in range(self.config["epochs"]):
            train_log = self.model.train(self.train_dataloader, self.writer)
            valid_log = None
            # 1-based count so valid_every=N validates after every N completed epochs.
            if (epoch + 1) % valid_every == 0:
                valid_log = self.model.valid(self.val_dataloader, self.writer)
            history.append({"epoch": epoch, "train": train_log, "valid": valid_log})
        return history
if __name__ == "__main__":
    # CLI entry point: python trainer.py --config <path to YAML config>
    import argparse

    parser = argparse.ArgumentParser(description="trainer script")
    parser.add_argument("--config", required=True, type=str, help="path to config file")
    args = parser.parse_args()
    trainer = Trainer(args)
    # Start training; without this call the script only builds the trainer and exits.
    trainer._run()