ddrdataset.py
from torch.utils import data
from torchvision import transforms as T
from PIL import Image

class DDR_dataset(data.Dataset):
    """DDR fundus-image dataset; ungradable images (grade 5) are filtered out."""

    def __init__(self, train=True, val=False, test=False, multi=25):
        self.test = test
        self.train = train
        self.val = val
        self.multi = multi
        self.path = '/raid/hjl/DivideMix-DDR/DDR_preprocess1024/'
        self.imgs = []
        if test:
            with open(self.path + 'test.txt', encoding='utf-8') as file:
                for line in file.readlines():
                    line = line.strip('\n')
                    img = line.split(' ')[0]
                    label = int(line.split(' ')[1])
                    if label != 5:  # grade 5 marks ungradable images; skip them
                        self.imgs.append([self.path + 'preprocess1024_test/' + img, label])
        elif val:
            with open(self.path + 'valid.txt', encoding='utf-8') as file:
                for line in file.readlines():
                    line = line.strip('\n')
                    img = line.split(' ')[0]
                    label = int(line.split(' ')[1])
                    if label != 5:
                        self.imgs.append([self.path + 'preprocess1024_valid/' + img, label])
        elif train:
            with open(self.path + 'train.txt', encoding='utf-8') as file:
                for line in file.readlines():
                    line = line.strip('\n')
                    img = line.split(' ')[0]
                    label = int(line.split(' ')[1])
                    if label != 5:
                        self.imgs.append([self.path + 'preprocess1024_train/' + img, label])
        self.imglen = len(self.imgs)
        print(self.imglen)  # number of images in this split
        self.nor = T.Normalize(  # original per-channel mean/std
            mean=[.426, .298, .213], std=[.277, .203, .169])
        data_aug = {
            'brightness': 0.4,       # how much to jitter brightness
            'contrast': 0.4,         # how much to jitter contrast
            'saturation': 0.4,       # how much to jitter saturation
            'hue': 0.1,              # how much to jitter hue
            'scale': (0.8, 1.2),     # crop size range, relative to the original size
            'ratio': (0.8, 1.2),     # aspect-ratio range of the crop, relative to the original ratio
            'degrees': (-180, 180),  # range of rotation degrees to sample from
            'translate': (0.2, 0.2)  # maximum absolute fraction for horizontal and vertical translation
        }
        if train:
            self.transform = T.Compose([
                T.Resize((640, 640)),
                T.RandomHorizontalFlip(),
                T.RandomVerticalFlip(),
                # T.ColorJitter(
                #     brightness=data_aug['brightness'],
                #     contrast=data_aug['contrast'],
                #     saturation=data_aug['saturation'],
                #     hue=data_aug['hue']
                # ),
                T.RandomResizedCrop(
                    size=(512, 512),
                    scale=data_aug['scale'],
                    ratio=data_aug['ratio']
                ),
                T.RandomAffine(
                    degrees=data_aug['degrees'],
                    # translate=data_aug['translate']
                ),
                # T.RandomGrayscale(0.2),
                T.ToTensor(),
                self.nor
            ])
        elif val or test:
            self.transform = T.Compose([
                T.Resize((512, 512)),
                T.ToTensor(),
                self.nor
            ])

    def __getitem__(self, index):
        img, label_grad = self.imgs[index]
        data = Image.open(img).convert('RGB')
        data = self.transform(data)
        label_clf = 0 if label_grad == 0 else 1  # binary label: no DR (0) vs. any DR (1)
        if self.multi == 25:  # return both labels (also used for CAM visualisation)
            # return data, label_clf, label_grad, img  # cam
            return data, label_clf, label_grad
        elif self.multi == 2:  # binary classification only
            return data, label_clf
        elif self.multi == 5:  # 5-class grading only
            return data, label_grad
        raise ValueError('multi must be one of 25, 2 or 5, got {}'.format(self.multi))

    def __len__(self):
        return len(self.imgs)

if __name__ == '__main__':
    dst = DDR_dataset(train=True, val=False, test=False, multi=5)
    for index in range(len(dst)):
        data, label_grad = dst[index]
        print(label_grad)
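
# Usage sketch (not part of the original file; batch size, shuffle flag and worker
# count below are illustrative assumptions): the dataset drops straight into a
# standard torch.utils.data.DataLoader for batched training or evaluation.
#
#     from torch.utils.data import DataLoader
#     train_set = DDR_dataset(train=True, val=False, test=False, multi=5)
#     loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)
#     for images, grades in loader:
#         # images: (B, 3, 512, 512) normalised float tensors; grades: labels in {0..4}
#         pass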