-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAD_3DRandomPatch.py
104 lines (81 loc) · 3.4 KB
/
AD_3DRandomPatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import nibabel as nib
import os
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from PIL import Image
import random
import torch
# Axis pairs passed as ``axis=`` to the mean() calls in getRandomPatches.
# Each tuple names the two axes that are averaged away, leaving a 1-D
# intensity profile along the remaining axis of a 3-D volume.
# NOTE(review): the AX/COR/SAG names suggest axial/coronal/sagittal
# orientation; that depends on how the volumes are stored -- confirm.
NON_AX = (1, 2)  # keeps axis 0
NON_COR = (0, 2)  # keeps axis 1
NON_SAG = (0, 1)  # keeps axis 2
class AD_3DRandomPatch(Dataset):
    """Dataset that yields random 3-D patches from volumes listed in a split file.

    Every non-blank line of ``data_file`` describes one sample; its first
    whitespace-separated token is the image file name, resolved relative to
    ``root_dir``. Any further tokens on the line are ignored by this class.
    """

    def __init__(self, root_dir, data_file):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
        """
        self.root_dir = root_dir
        self.data_file = data_file

    def _read_entries(self):
        """Return the split file as a list of token lists, skipping blank lines.

        Blank (or whitespace-only) lines carry no sample; counting them would
        make ``__len__`` report indices that ``__getitem__`` cannot serve.
        """
        with open(self.data_file) as df:
            return [line.split() for line in df if line.strip()]

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self._read_entries())

    def __getitem__(self, idx):
        """Load sample ``idx``, resize it to 110x110x110 and sample patches.

        Returns:
            dict: ``{"patch": patches}`` where ``patches`` is the list
            produced by ``getRandomPatches``.

        Raises:
            IndexError: if ``idx`` is outside the listed samples.
        """
        entries = self._read_entries()
        img_name = entries[idx][0]
        image_path = os.path.join(self.root_dir, img_name)
        image = nib.load(image_path)
        # ``get_data()`` is deprecated in nibabel and removed in recent
        # releases; reading through ``dataobj`` yields the same stored array.
        volume = np.asarray(image.dataobj)
        image_array = resize_image(volume, (110, 110, 110))
        patch_samples = getRandomPatches(image_array)
        return {"patch": patch_samples}
def customToTensor(pic):
    """Convert an array into a float tensor with a leading channel axis.

    The values are cast to float32 and divided by 255.

    Args:
        pic: a ``numpy.ndarray`` or any array-like that ``numpy.asarray``
            can turn into a numeric array.

    Returns:
        torch.FloatTensor of shape ``(1, *pic.shape)``.

    Raises:
        TypeError: if ``pic`` cannot be represented as a numeric array
            (raised by ``torch.from_numpy``).
    """
    # Coerce array-likes up front so ``img`` is always bound; without this a
    # non-ndarray input leaves the function with nothing valid to return.
    if not isinstance(pic, np.ndarray):
        pic = np.asarray(pic)
    img = torch.from_numpy(pic)
    # Add the channel dimension expected by downstream consumers.
    img = torch.unsqueeze(img, 0)
    # backward compatibility
    return img.float().div(255)
def resize_image(img_array, trg_size):
    """Resample ``img_array`` to ``trg_size`` using ``skimage.transform.resize``.

    The call uses reflect padding, keeps the input value range
    (``preserve_range=True``) and disables anti-aliasing.

    Args:
        img_array: input array to resample.
        trg_size: target shape, e.g. ``(110, 110, 110)``.

    Returns:
        numpy.ndarray: the resampled array.

    Raises:
        TypeError: if ``resize`` returns something that is not an ndarray.
    """
    res = resize(img_array, trg_size, mode='reflect', preserve_range=True, anti_aliasing=False)
    # type check -- raise a real exception; raising a bare string is itself
    # an error in Python 3 and hides the actual problem.
    if not isinstance(res, np.ndarray):
        raise TypeError(
            "type error! resize returned %s, expected numpy.ndarray"
            % type(res).__name__
        )
    return res
def getRandomPatches(image_array, num_patches=1000):
    """Sample random 7x7x7 patches from the non-empty region of a 3-D volume.

    For each axis, the volume is averaged over the other two axes and the
    first and last positions whose mean is > 0 delimit the occupied range.
    The range along axis 0 is tightened by 5 positions at the start and 10
    at the end. Patch centres are drawn uniformly from
    ``[first - 3, last - 3]`` on every axis (clamped so the patch never
    starts before index 0), and a patch whose voxel sum is 0 is redrawn.
    Each accepted patch is rescaled by ``255 / 1500`` and converted with
    ``customToTensor``.

    Args:
        image_array: 3-D ``numpy.ndarray`` to sample from.
        num_patches: number of patches to return (default 1000).

    Returns:
        list: ``num_patches`` results of ``customToTensor``, one per patch.

    Raises:
        ValueError: if an axis has no position with a positive mean, or if
            the resulting sampling range along an axis is empty.
    """
    half = 3  # a patch spans centre - 3 .. centre + 3 (7 voxels) per axis

    sample_ranges = []
    for axis_index, other_axes in enumerate((NON_AX, NON_COR, NON_SAG)):
        profile = image_array.mean(axis=other_axes)
        # Positions are looked up directly rather than by value: searching
        # for the last positive *value* returns its first occurrence, which
        # is the wrong index whenever that value appears more than once.
        occupied = np.flatnonzero(profile > 0)
        if occupied.size == 0:
            raise ValueError(
                "image has no position with a positive mean along axis %d"
                % axis_index
            )
        first, last = int(occupied[0]), int(occupied[-1])
        if axis_index == 0:
            # Extra margin applied only along axis 0.
            first += 5
            last -= 10
        # Centres below ``half`` would give a negative slice start, which
        # wraps around and yields an empty patch; exclude them up front
        # instead of relying on the zero-sum redraw to discard them.
        low = max(first - half, half)
        high = last - half
        if low > high:
            raise ValueError(
                "occupied range along axis %d is too small to sample a patch"
                % axis_index
            )
        sample_ranges.append((low, high))

    patches = []
    for _ in range(num_patches):
        while True:
            ax_c, cor_c, sag_c = (
                random.randint(low, high) for low, high in sample_ranges
            )
            patch = image_array[
                ax_c - half:ax_c + half + 1,
                cor_c - half:cor_c + half + 1,
                sag_c - half:sag_c + half + 1,
            ]
            # Redraw patches that contain no signal.
            if patch.sum() != 0:
                break
        # Map intensities so that 1500 corresponds to 255 before the
        # tensor conversion divides by 255 again.
        patch = patch / 1500 * 255
        patches.append(customToTensor(patch))
    return patches
# plt.imshow(array[i][3,:,:], cmap = 'gray')
# plt.savefig('./section.png', dpi=100)