train.py
from time import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import copy
import argparse
from utils import *
from tqdm import tqdm
import os
import json
from unet import UNet
# Get arguments for network configuration from the command line
parser = argparse.ArgumentParser(description="UNet model training loop")
parser.add_argument("-n", "--num_layers", type=int, default=4, help="No of encoder and decoder layers in the UNet network")
parser.add_argument("-b", "--bottleneck", type=str, default="conv", help="Choose architecture of bottleneck layer, conv")
parser.add_argument("-f", "--features", nargs='+', default=["depth", "normal", "relative_normal", "albedo", "roughness"], help="Features to include")
parser.add_argument("-t", "--tag", type=str, default="", help="Tag to be added for while saving files")
parser.add_argument("-a", "--alpha", type=float, default=0.5, help="Coefficient for L1 loss")
args = parser.parse_args()
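# Example invocation (the flag values below are illustrative, not a prescribed configuration):
#   python train.py -n 4 -b conv -f depth normal albedo -a 0.5 -t baseline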
# Paths to the text files
split_file_folder = "data"
train_txt = f"{split_file_folder}/train.txt"
val_txt = f"{split_file_folder}/val.txt"
test_txt = f"{split_file_folder}/test.txt"
inference_txt = f"{split_file_folder}/inference.txt"
# Paths to the image folders
data_path = "data/raw_data"
train_loader, val_loader, test_loader, _, num_features = create_datasets(train_txt, val_txt, test_txt, inference_txt, data_path, args.features)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the model, loss function, and optimizer
model = UNet(num_features, 3, args.num_layers, args.bottleneck).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training parameters
num_epochs = 1
patience = 5 # Early stopping patience
best_loss = float('inf')
patience_counter = 0
l2_loss = 0    # 0: optimize L1 + HFEN, 1: optimize L2 + HFEN (switched after early stopping fires once)
l1_epochs = 0  # epoch at which the L1 phase ended
l2_epochs = 0  # epoch at which training stopped during the L2 phase
training_loss = []
validation_loss = []
validation_psnr = []
# For saving the best model
best_model_wts = copy.deepcopy(model.state_dict())
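# Two-stage loss schedule: train with L1 + HFEN until early stopping fires,
# then reset the patience counter and continue with L2 + HFEN until it fires again.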
# Training loop
for epoch in range(num_epochs):
    model.train()
    # adjust_learning_rate(optimizer, epoch)
    running_loss = 0.0
    t = time()
    for noisy_image, clean_image in tqdm(train_loader):
        # Move tensors to the appropriate device
        noisy_image, clean_image = noisy_image.to(device), clean_image.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(noisy_image)
        if l2_loss == 0:
            loss = args.alpha * l1_norm(outputs, clean_image) + (1 - args.alpha) * HFEN(outputs, clean_image)
        elif l2_loss == 1:
            loss = args.alpha * l2_norm(outputs, clean_image) + (1 - args.alpha) * HFEN(outputs, clean_image)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        # Accumulate the loss, weighted by batch size
        running_loss += loss.item() * noisy_image.size(0)
    # Calculate training loss for the epoch
    epoch_loss = running_loss / len(train_loader.dataset)
    training_loss.append(epoch_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}, Time: {time()-t:.1f}s')
    # Validation phase (the validation loss always uses the L1 + HFEN criterion, regardless of the training phase)
    model.eval()
    val_loss = 0.0
    psnr_values = []
    with torch.no_grad():
        for noisy_image, clean_image in tqdm(val_loader):
            noisy_image, clean_image = noisy_image.to(device), clean_image.to(device)
            outputs = model(noisy_image)
            loss = args.alpha * l1_norm(outputs, clean_image) + (1 - args.alpha) * HFEN(outputs, clean_image)
            val_loss += loss.item() * noisy_image.size(0)
            psnr_values += [psnr(outputs[i], clean_image[i]) for i in range(outputs.size(0))]
    # Calculate validation loss and PSNR for the epoch
    val_loss /= len(val_loader.dataset)
    validation_loss.append(val_loss)
    validation_psnr.append(np.mean(psnr_values))
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation PSNR: {validation_psnr[-1]:.2f}')
    # Check for early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            if l2_loss == 0:
                print("Early stopping triggered, switching to L2 loss for the remaining epochs")
                l2_loss = 1
                patience_counter = 0
                l1_epochs = epoch
                best_loss = float('inf')
            elif l2_loss == 1:
                print("Early stopping triggered, exiting")
                l2_epochs = epoch
                break
# Load the best model weights
model.load_state_dict(best_model_wts)
os.makedirs("results", exist_ok=True)
feature_string = '_'.join([x[:3] for x in args.features])
if args.tag != "":
    experiment_name = f"{args.tag}_"
else:
    experiment_name = ""
experiment_name += f"n_{args.num_layers}_alpha_{args.alpha:.2f}_feat_{feature_string}"
folder = f"results/{experiment_name}"
os.makedirs(folder, exist_ok=True)
# Save the model
torch.save(model.state_dict(), f"{folder}/checkpoint.pth")
config = {
    "n": args.num_layers,
    "features": args.features,
    "tag": args.tag,
    "alpha": args.alpha,
    "l1_epochs": l1_epochs,
    "l2_epochs": l2_epochs
}
with open(f"{folder}/config.json", 'w') as fp:
    json.dump(config, fp)
open(f"{folder}/training_loss.txt", "w").write("\n".join([str(x) for x in training_loss]))
open(f"{folder}/validation_loss.txt", "w").write("\n".join([str(x) for x in validation_loss]))
open(f"{folder}/validation_psnr.txt", "w").write("\n".join([str(x) for x in validation_psnr]))