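"""train.py

Training entry point for the sejong_music Jeonggan-score models: builds the
dataset, model, loss, and trainer named in the Hydra config, then trains and
prints an inference result from the best checkpoint.

Example invocation (a sketch: the override keys follow the config fields used
below, but the actual defaults live in yamls/transformer_jeonggan.yaml and may
differ):

    python train.py model.depth=6 train.batch_size=16 general.make_log=false
"""
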
import datetime
import random
import wandb
import hydra
from pathlib import Path
import copy
from typing import List, Union
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf
from sejong_music import yeominrak_processing, model_zoo, loss, utils, jg_code, trainer as trainer_zoo
from sejong_music.yeominrak_processing import pack_collate
from sejong_music.model_zoo import get_emb_total_size
from sejong_music.loss import nll_loss, focal_loss
from sejong_music.trainer import JeongganTrainer, BertTrainer
from sejong_music.train_utils import CosineLRScheduler
# from sejong_music.constants import PART, POSITION, PITCH
from sejong_music.jg_code import JeongganDataset, JGMaskedDataset, ABCDataset
from sejong_music.full_inference import Generator
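

# The experiment name encodes the launch time plus the key hyper-parameters,
# so individual runs are easy to tell apart in wandb.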
def make_experiment_name_with_date(config):
  current_time_in_str = datetime.datetime.now().strftime("%m%d-%H%M")
  return f'{current_time_in_str}_data={config.dataset_class}_depth={config.model.depth}_head={config.model.num_heads}_drop={config.model.dropout}_is_beat={config.data.use_offset}_pos_count={config.data.is_pos_counter}_{config.general.exp_name}_{config.model_class}'
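

# Hydra loads yamls/transformer_jeonggan.yaml and merges any command-line
# overrides; it also changes the working directory at runtime, hence the
# get_original_cwd() calls below.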
@hydra.main(config_path='yamls/', config_name='transformer_jeonggan')
def main(config: DictConfig):
  config = get_emb_total_size(config)
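  # With logging enabled the run is tracked on Weights & Biases and checkpoints
  # are written under the wandb run directory; otherwise a local debug
  # directory is used.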
  if config.general.make_log:
    wandb.init(
      project=config.general.project,
      entity=config.general.entity,
      name=make_experiment_name_with_date(config),
      config=OmegaConf.to_container(config),
      dir=Path(hydra.utils.get_original_cwd())
    )
    save_dir = Path(wandb.run.dir) / 'checkpoints'
  else:
    save_dir = Path('wandb/debug/checkpoints')
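  # Hydra runs inside its own output directory, so resolve relative paths
  # against the original project root before creating the checkpoint folder.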
  original_wd = Path(hydra.utils.get_original_cwd())
  if not save_dir.is_absolute():
    save_dir = original_wd / save_dir
  save_dir.mkdir(parents=True, exist_ok=True)
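
  # Dataset, model, and trainer classes are all looked up by name from the config.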
  dataset_class: Union[JeongganDataset, JGMaskedDataset] = getattr(jg_code, config.dataset_class)
  model_class = getattr(model_zoo, config.model_class)
  trainer_class: Union[JeongganTrainer, BertTrainer] = getattr(trainer_zoo, config.trainer_class)
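
  # Training and validation splits are built from the same score directory
  # (path is relative to the project root).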
  train_dataset = dataset_class(data_path=original_wd / 'music_score/jg_dataset',
                                slice_measure_num=config.data.slice_measure_num,
                                is_valid=False,
                                augment_param=config.aug,
                                num_max_inst=config.data.num_max_inst)
  val_dataset = dataset_class(data_path=original_wd / 'music_score/jg_dataset',
                              is_valid=True,
                              augment_param=config.aug,
                              num_max_inst=config.data.num_max_inst)
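
  # Collate and loss functions are also selected by name; the whole validation
  # set is evaluated as a single batch.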
  collate_fn = getattr(utils, config.collate_fn)
  loss_fn = getattr(loss, config.loss_fn)
  train_loader = DataLoader(train_dataset,
                            batch_size=config.train.batch_size,
                            shuffle=True,
                            collate_fn=collate_fn,
                            num_workers=4)
  valid_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False, collate_fn=collate_fn, drop_last=True)
  device = 'cuda'  # training assumes a CUDA-capable GPU

  # --- Save config and tokenizer --- #
  with open(save_dir / 'config.yaml', 'w') as f:
    OmegaConf.save(config, f)
  tokenizer_vocab_path = save_dir / 'tokenizer_vocab.json'
  train_dataset.tokenizer.save_to_json(tokenizer_vocab_path)
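
  # --- Model, optimizer, and LR schedule --- #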
  model = model_class(train_dataset.tokenizer, config.model).to(device)
  num_model_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  print(f'Number of model parameters: {num_model_parameters}')
  if config.general.make_log:
    wandb.run.summary['num_model_parameters'] = num_model_parameters
  optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)
  # total_steps matches the number of iterations passed to train_by_num_iteration below
  scheduler = CosineLRScheduler(optimizer, total_steps=config.train.num_epoch * len(train_loader), warmup_steps=1000, lr_min_ratio=0.001, cycle_length=1.0)

  # --- Training --- #
  atrainer = trainer_class(model=model,
                           optimizer=optimizer,
                           loss_fn=loss_fn,
                           train_loader=train_loader,
                           valid_loader=valid_loader,
                           device=device,
                           save_log=config.general.make_log,
                           save_dir=save_dir,
                           scheduler=scheduler,
                           min_epoch_for_infer=100)
  # generator = Generator(config=None,
  #                       model=model,
  #                       output_dir=save_dir,
  #                       inferencer=atrainer.inferencer,
  #                       is_abc=dataset_class == ABCDataset,
  #                       )
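
  # Train for num_epoch passes over the training loader, then reload the best
  # checkpoint and print a sample inference result.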
  atrainer.train_by_num_iteration(config.train.num_epoch * len(train_loader))
  atrainer.load_best_model()
  print(atrainer.make_inference_result())

if __name__ == '__main__':
  main()