train.py
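"""Train a graph diffusion model on the QM9 dataset with HydraGNN.

Example invocation (the paths shown are the argparse defaults defined at the
bottom of this script):

    python train.py -c examples/qm9/qm9_marginal.json -d examples/qm9/dataset -ds 100
"""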
import argparse
import datetime
import json
import os
import random
import sys

import numpy as np
import torch
import torch_geometric
import wandb
from rdkit import Chem
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader

import hydragnn
from hydragnn.utils.print import setup_log
from hydragnn.utils.input_config_parsing import config_utils
from hydragnn.utils.distributed import setup_ddp, get_distributed_model
from hydragnn.preprocess.graph_samples_checks_and_updates import update_predicted_values

from src.utils import diffusion_utils as du
from src.processes.diffusion import DiffusionProcess
from src.processes.equivariant_diffusion import EquivariantDiffusionProcess
from src.processes.marginal_diffusion import MarginalDiffusionProcess
from src.utils.train_utils import train_model, get_train_transform, insert_t


def train(args):
    # Set this path for output.
    try:
        os.environ["SERIALIZED_DATA_PATH"]
    except KeyError:
        os.environ["SERIALIZED_DATA_PATH"] = os.getcwd()

    # Configurable run choices (JSON file that accompanies this example script).
    with open(args.config_path, "r") as f:
        config = json.load(f)
    verbosity = config["Verbosity"]["level"]
    var_config = config["NeuralNetwork"]["Variables_of_interest"]

    # Always initialize for multi-rank training.
    world_size, world_rank = setup_ddp()

    # Create a MarginalDiffusionProcess object.
    dp = MarginalDiffusionProcess(
        args.diffusion_steps, marg_dist=du.get_marg_dist(root_path=args.data_path)
    )
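    # The noising prior is built from the marginal distribution computed over the
    # dataset at args.data_path (see du.get_marg_dist).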

    # Create a training transform function for the QM9 dataset.
    train_tform = get_train_transform(dp)

    # Load the QM9 dataset from torch_geometric with the training transform applied.
    # TODO: this should be generalized, a la fine-tuning.
    dataset = torch_geometric.datasets.QM9(root=args.data_path, transform=train_tform)

    # Limit the number of samples if specified.
    if args.samples is not None:
        dataset = dataset[: args.samples]
    else:
        print("Training on the full dataset")

    # TODO: modify the config to move Training outside of NeuralNetwork.
    # Split into train, validation, and test sets.
    train, val, test = hydragnn.preprocess.split_dataset(
        dataset, config["NeuralNetwork"]["Training"]["perc_train"], False
    )

    # Create dataloaders for PyTorch training.
    (
        train_loader,
        val_loader,
        test_loader,
    ) = hydragnn.preprocess.create_dataloaders(
        train, val, test, config["NeuralNetwork"]["Training"]["batch_size"]
    )

    # Update the config with the dataloaders.
    config = config_utils.update_config(config, train_loader, val_loader, test_loader)
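    # update_config fills in dataset-dependent fields of the config from the loaders
    # (see hydragnn's config_utils).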

    # Log the run and the updated config to Weights & Biases.
    wandb.init(project="graph diffusion model", config=config)
    # models_path = './models'
    # with open(os.path.join(models_path, args.run_name, 'config.json'), 'w') as json_file:
    #     json.dump(config, json_file, indent=4)

    # Create the model from the config specifications.
    model = hydragnn.models.create_model_config(
        config=config["NeuralNetwork"],
        verbosity=verbosity,
    )

    # Distribute the model across ranks (if necessary).
    # model = get_distributed_model(model, verbosity)

    # Define the training optimizer and scheduler.
    learning_rate = config["NeuralNetwork"]["Training"]["Optimizer"]["learning_rate"]
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-5
    )
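    # NOTE: the scheduler is created here but is not passed to train_model below;
    # stepping it (e.g. on validation loss) would have to happen inside train_model.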

    # TODO: move this to train_utils.py, name specific.
    def loss(outputs, targets):
        # Weighted sum of a regression term (MSE) and a classification term
        # (cross-entropy); the indexing assumes outputs/targets are ordered to
        # match the model's continuous and categorical heads.
        l1 = torch.nn.functional.mse_loss(outputs[1], targets[1])
        l2 = torch.nn.functional.cross_entropy(outputs[0], targets[0])
        return 2 * l1 + l2

    # Run training with the given model and dataset.
    model = train_model(
        model,
        loss,
        optimizer,
        train_loader,
        config["NeuralNetwork"]["Training"]["num_epoch"],
        logger=wandb.run,
    )
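    # NOTE: only train_loader is consumed by train_model here; val_loader and
    # test_loader are currently used only by config_utils.update_config above.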

    # Save the model.
    # torch.save(model.module.state_dict(), os.path.join(models_path, args.run_name, 'model.pth'))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Default run name used if -l/--run_name is not specified.
    default_log_name = "test"
    parser.add_argument("-s", "--samples", type=int)
    parser.add_argument("-ds", "--diffusion_steps", type=int, default=100)
    parser.add_argument("-l", "--run_name", type=str, default=default_log_name)
    parser.add_argument(
        "-c", "--config_path", type=str, default="examples/qm9/qm9_marginal.json"
    )
    parser.add_argument("-d", "--data_path", type=str, default="examples/qm9/dataset")
    # Store the parsed arguments in args.
    args = parser.parse_args()
    train(args)