-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
143 lines (133 loc) · 7.52 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import sys
import argparse
import torch
import odak
from torch.utils.data import DataLoader
import utils
import time
__title__ = 'Holobeam'
def main(
    settings_filename = './settings/jasper.txt',
    weights_filename = None,
    input_filename = None,
    mode = 'train'
):
    """
    Command-line entry point: estimate a hologram for one input, or train the model.

    The ``--settings``, ``--weights`` and ``--input`` command-line arguments override the
    matching keyword arguments. When an input file is given, the model estimates a
    hologram for it and writes the input, the estimated phase and per-plane
    reconstructions to the output directory named in the settings, then returns.
    Otherwise the model is trained on the settings' train dataset and its weights are
    saved to ``<output directory>/weights.pt``.

    Parameters
    ----------
    settings_filename : str
        Settings file, read with ``odak.tools.load_dictionary``.
    weights_filename  : str or None
        Optional weights file loaded into the model before estimation or training.
    input_filename    : str or None
        Optional ``.png`` or ``.pt`` input; when given, estimation runs instead of training.
    mode              : str
        Currently unused; kept for backward compatibility of the signature.
    """
    settings_filename, weights_filename, input_filename = _parse_arguments(
        settings_filename,
        weights_filename,
        input_filename
    )
    settings = odak.tools.load_dictionary(settings_filename)
    device = torch.device(settings["general"]["device"])
    odak.tools.check_directory(settings["general"]["output directory"])
    model = odak.learn.wave.holobeam_multiholo(
        n_input = settings["model"]["number of input channels"],
        n_hidden = settings["model"]["number of hidden channels"],
        n_output = settings["model"]["number of output channels"],
        device = device
    )
    if weights_filename is not None:
        model.load_weights(weights_filename)
    if input_filename is not None:
        _estimate(model, settings, input_filename, device)
        return
    _train(model, settings, device)


def _parse_arguments(settings_filename, weights_filename, input_filename):
    """Parse the command line; return (settings, weights, input) filenames, CLI values taking precedence."""
    parser = argparse.ArgumentParser(description = __title__)
    parser.add_argument(
        '--settings',
        type = argparse.FileType('r'),
        help = 'Filename for the settings file. Default is {}.'.format(settings_filename)
    )
    parser.add_argument(
        '--weights',
        type = argparse.FileType('r'),
        help = 'Filename for the weights file.'
    )
    parser.add_argument(
        '--input',
        type = argparse.FileType('r'),
        help = 'Filename for an input data to estimate. Any RGB png file is good or a pt file generated by GitHub:complight/realistic_defocus.'
    )
    args = parser.parse_args()
    # argparse.FileType opens each file (validating that it exists); only the name is
    # needed here, so close the handles right away instead of leaking them.
    if args.settings is not None:
        settings_filename = str(args.settings.name)
        args.settings.close()
    if args.weights is not None:
        weights_filename = str(args.weights.name)
        args.weights.close()
    if args.input is not None:
        input_filename = str(args.input.name)
        args.input.close()
    return settings_filename, weights_filename, input_filename


def _load_input(input_filename, device):
    """Load the estimation input as a batched tensor on ``device``; exit non-zero on an unsupported extension."""
    extension = input_filename.lower()
    if extension.endswith('.png'):
        return odak.learn.tools.load_image(input_filename, normalizeby = 255., torch_style = True).to(device).unsqueeze(0)
    if extension.endswith('.pt'):
        return odak.learn.tools.torch_load(input_filename)["target"][1].to(device).unsqueeze(0)
    # sys.exit with a message prints it to stderr and exits with status 1, so callers
    # and shells can detect the failure (a bare sys.exit() would report success).
    sys.exit('Bad input file extension. Please provide PNG or PT files.')


def _estimate(model, settings, input_filename, device):
    """Estimate a hologram for one input and save the input, phase and reconstructions as PNGs."""
    output_directory = settings["general"]["output directory"]
    input_data = _load_input(input_filename, device)
    # Rescale from [0, 1] to [-1, 1]; the saved preview below maps it back with / 2 + 0.5.
    input_data = (input_data - 0.5) * 2
    # Only channel 0 of the input is fed to the model; the remaining input channels stay zero.
    model_input = torch.zeros(
        input_data.shape[0],
        settings["model"]["number of input channels"],
        input_data.shape[-2],
        input_data.shape[-1],
        device = device
    )
    model_input[:, 0] = input_data[:, 0]
    resolution = [model_input.shape[-2], model_input.shape[-1]]
    odak.learn.tools.save_image(
        '{}/estimate_input.png'.format(output_directory),
        (model_input[0, 0] / 2.) + 0.5,
        cmin = 0.,
        cmax = 1.
    )
    propagator = odak.learn.wave.propagator(
        wavelengths = settings['hologram']['wavelength'],
        pixel_pitch = settings['hologram']['pixel pitch'],
        resolution = resolution,
        aperture_size = settings['hologram']['pinhole size'],
        number_of_frames = 1,
        number_of_depth_layers = settings['hologram']['number of planes'],
        volume_depth = settings['hologram']['volume depth'],
        image_location_offset = settings['hologram']['location offset'],
        propagation_type = settings['hologram']['propagation type'],
        propagator_type = settings['hologram']['propagator type'],
        back_and_forth_distance = settings['hologram']['back and forth distance'],
        method = 'conventional',
        device = device
    )
    # torch.no_grad() only takes effect when entered as a context manager; a bare call
    # is a no-op and autograd would still record the whole inference pass.
    with torch.no_grad():
        estimate = model.forward(model_input.detach().clone(), test = True).detach()
        # The estimate is scaled by 2 * pi before propagation, i.e. treated as a phase in [0, 1].
        reconstruction_intensities = propagator.reconstruct(estimate * 2 * odak.pi)
        # NOTE(review): assumes input channel 2 holds a reference hologram phase in [-1, 1]
        # (e.g. a realistic_defocus .pt target); for an RGB PNG this is the blue channel,
        # so the "ground truth" reconstruction may not be meaningful. TODO confirm.
        hologram_phase = (input_data[0, 2].unsqueeze(0).unsqueeze(0) / 2. + 0.5) * 2 * odak.pi
        reconstruction_intensities_ground_truth = propagator.reconstruct(hologram_phase)
    odak.learn.tools.save_image(
        '{}/estimate_phase.png'.format(output_directory),
        estimate[0, 0],
        cmin = 0.,
        cmax = 1.
    )
    for i in range(settings["hologram"]["number of planes"]):
        odak.learn.tools.save_image(
            '{}/estimate_reconstruction_{:04d}.png'.format(output_directory, i),
            reconstruction_intensities[0, i, 0],
            cmin = 0.,
            cmax = 1.
        )
        odak.learn.tools.save_image(
            '{}/ground_truth_reconstruction_{:04d}.png'.format(output_directory, i),
            reconstruction_intensities_ground_truth[0, i, 0],
            cmin = 0.,
            cmax = 1.
        )


def _train(model, settings, device):
    """Train the model on the configured dataset and always save weights to <output directory>/weights.pt."""
    output_directory = settings["general"]["output directory"]
    train_dataset = utils.hologram_dataset(
        directory = settings['train dataset']['directory'],
        device = device
    )
    train_dataloader = DataLoader(train_dataset, batch_size = 1, shuffle = settings['train dataset']['shuffle'])
    weights_filename = '{}/weights.pt'.format(output_directory)
    try:
        model.fit(
            train_dataloader,
            number_of_epochs = settings["model"]["number of epochs"],
            learning_rate = settings["model"]["learning rate"],
            directory = output_directory,
            save_at_every = settings["model"]["save at every"]
        )
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop training early; weights are still saved below.
        print('Training interrupted by user.')
    finally:
        # Save whatever was learned even if training failed; a genuine error still
        # propagates afterwards instead of being silently swallowed.
        odak.tools.check_directory(output_directory)
        model.save_weights(filename = weights_filename)
    print('Training exited and weights are saved to {}'.format(weights_filename))
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the exit status.
    exit_status = main()
    sys.exit(exit_status)