from torch.utils.data import DataLoader
import torch
import numpy as np
from Train_options import Options
from SuperPoint import SuperPoint
from Databases import Database
import Utils
from Utils import LogText, CheckPaths
from FanClass import FAN_Model
import resource
import imgaug.augmenters as iaa
import torch.nn as nn
import yaml
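

# Train_Firststep.py: first-step training. Bootstrap pseudo-groundtruth
# keypoints with SuperPoint, train the FAN detector on them, then alternate
# between refreshing the pseudo-labels (clustering rounds) and training on
# the updated labels.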
def main():
    step = 1
    experiment_options = Options()
    global args
    # config parameters
    args = experiment_options.args
    experiment_name = args.experiment_name
    dataset_name = args.dataset_name
    number_of_workers = args.num_workers
    resume = args.resume

    hyperparameters = experiment_options.GetHyperparameters(step, dataset_name)
    # training parameters
    batchSize = hyperparameters.batchSize
    weight_decay = hyperparameters.weight_decay
    lr = hyperparameters.lr
    number_of_clusters = hyperparameters.number_of_clusters
    number_of_clustering_rounds = hyperparameters.number_of_clustering_rounds
    nms_thres_superpoint = hyperparameters.nms_thres_superpoint
    confidence_thres_superpoint = hyperparameters.confidence_thres_superpoint
    use_box = hyperparameters.use_box
    remove_superpoint_outliers_percentage = hyperparameters.remove_superpoint_outliers_percentage
    training_iterations_before_first_clustering = hyperparameters.training_iterations_before_first_clustering
    confidence_thres_FAN = hyperparameters.confidence_thres_FAN
    UseScales = hyperparameters.UseScales
    RemoveBackgroundClusters = hyperparameters.RemoveBackgroundClusters

    # load paths
    with open('paths/main.yaml') as file:
        paths = yaml.load(file, Loader=yaml.FullLoader)
    CheckPaths(paths, dataset_name)
    log_path = paths['log_path']
    path_to_superpoint_checkpoint = paths['path_to_superpoint_checkpoint']

    # this function creates the /Logs and /CheckPoints directories at log_path
    Utils.initialize_log_dirs(experiment_name, log_path)
LogText(f"Experiment Name {experiment_name}\n"
f"Database {dataset_name}\n"
"Training Parameters: \n"
f"Batch size {batchSize} \n"
f"Learning rate {lr} \n"
f"Weight Decay {weight_decay} \n"
f"Training iterations before first clustering {training_iterations_before_first_clustering} \n"
f"Number of clustering rounds {number_of_clustering_rounds} \n"
f"FAN detection threshold {confidence_thres_FAN} \n"
f"Number of Clusters {number_of_clusters} \n"
f"Outlier removal {remove_superpoint_outliers_percentage} \n"
, experiment_name, log_path)
LogText("Training of First step begins", experiment_name,log_path)
    # augmentations for the first step
    augmentations = iaa.Sequential([
        iaa.Sometimes(0.3,
                      iaa.GaussianBlur(sigma=(0, 0.5))
                      ),
        iaa.ContrastNormalization((0.85, 1.3)),
        iaa.Sometimes(0.5,
                      iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5)
                      ),
        iaa.Multiply((0.9, 1.1), per_channel=0.2),
        iaa.Sometimes(0.3,
                      iaa.MultiplyHueAndSaturation((0.5, 1.5), per_channel=True),
                      ),
        iaa.Affine(
            scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
            rotate=(-40, 40),
        )
    ])
    # selection of the dataloading function
    superpoint_dataloading_function = Database.get_image_superpoint
    if UseScales:
        superpoint_dataloading_function = Database.get_image_superpoint_multiple_scales

    superpoint = SuperPoint(number_of_clusters,
                            confidence_thres_superpoint,
                            nms_thres_superpoint,
                            path_to_superpoint_checkpoint,
                            experiment_name,
                            log_path,
                            remove_superpoint_outliers_percentage,
                            use_box,
                            UseScales,
                            RemoveBackgroundClusters,
                            )
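
    # dataset/dataloader for the initial SuperPoint pass over the training set
    # (shuffle=False so the extraction order over images stays fixed)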
    superpoint_dataset = Database(dataset_name, number_of_clusters,
                                  function_for_dataloading=superpoint_dataloading_function,
                                  augmentations=augmentations, use_box=use_box)
    dataloader = DataLoader(superpoint_dataset, batch_size=batchSize, shuffle=False,
                            num_workers=number_of_workers, drop_last=True)

    criterion = nn.MSELoss().cuda()
    FAN = FAN_Model(number_of_clusters, criterion, experiment_name, confidence_thres_FAN, log_path, step)
    FAN.init_firststep(lr, weight_decay, number_of_clusters, training_iterations_before_first_clustering)
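
    # on resume, reload the latest first-step checkpoint (if any) and its saved
    # pseudo-labels; otherwise bootstrap pseudo-groundtruth with SuperPoint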
    if resume:
        path_to_checkpoint, path_to_keypoints = Utils.GetPathsResumeFirstStep(experiment_name, log_path)
        if path_to_checkpoint is not None:
            FAN.load_trained_fiststep_model(path_to_checkpoint)
        keypoints = Utils.load_keypoints(path_to_keypoints)
    else:
        # get initial pseudo-groundtruth by applying SuperPoint on the training data
        keypoints = superpoint.CreateInitialPseudoGroundtruth(dataloader)

    dataset = Database(dataset_name, FAN.number_of_clusters, image_keypoints=keypoints,
                       function_for_dataloading=Database.get_FAN_firstStep_train, augmentations=augmentations)
    dataloader = DataLoader(dataset, batch_size=batchSize, shuffle=True,
                            num_workers=number_of_workers, drop_last=True)

    database_clustering = Database(dataset_name, FAN.number_of_clusters,
                                   function_for_dataloading=Database.get_FAN_inference)
    dataloader_clustering = DataLoader(database_clustering, batch_size=batchSize, shuffle=False,
                                       num_workers=number_of_workers, drop_last=True)
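
    # alternate training and pseudo-label refinement: each round trains the FAN
    # on the current pseudo-labels, then updates them from the model's own
    # predictions and rebuilds the training dataloader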
    for i in range(number_of_clustering_rounds):
        FAN.Train_step1(dataloader)
        keypoints = FAN.Update_pseudoLabels(dataloader_clustering, keypoints)

        dataset = Database(dataset_name, FAN.number_of_clusters, image_keypoints=keypoints,
                           function_for_dataloading=Database.get_FAN_firstStep_train, augmentations=augmentations)
        dataloader = DataLoader(dataset, batch_size=batchSize, shuffle=True,
                                num_workers=number_of_workers, drop_last=True)
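

# entry point: fix the RNG seeds for reproducibility and raise the soft
# open-file limit (useful when DataLoader spawns many worker processes)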
if __name__ == "__main__":
    torch.manual_seed(1993)
    torch.cuda.manual_seed_all(1993)
    np.random.seed(1993)

    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

    main()