You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
class USPS(data.Dataset):
    """USPS handwritten-digit dataset loaded from a bz2-compressed LIBSVM file.

    The training split is read from ``<root>/usps.bz2`` and the test split
    from ``<root>/usps.t.bz2``. Each non-blank line has the form
    ``<label> <index>:<value> <index>:<value> ...`` with a 1-based label and
    1-based feature indices. Features absent from a line are treated as 0.0,
    as the LIBSVM sparse format specifies.

    Attributes:
        data: uint8 array of shape (N, 16, 16).
        targets: list of N zero-based integer labels.
    """

    # Images are 16 x 16, i.e. 256 features per sample.
    _IMAGE_SIDE = 16

    def __init__(self, root, train=True, transform=None, target_transform=None):
        """
        Args:
            root (str): Directory containing ``usps.bz2`` / ``usps.t.bz2``.
            train (bool): Load the training file if True, else the test file.
            transform (callable, optional): Applied to each image in __getitem__.
            target_transform (callable, optional): Applied to each label in
                __getitem__.
        Raises:
            ValueError: If a feature index lies outside 1..256.
        """
        super().__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        filename = 'usps.bz2' if train else 'usps.t.bz2'
        full_path = os.path.join(self.root, filename)

        import bz2
        with bz2.open(full_path, 'rt', encoding='utf-8') as fp:
            # Skip blank lines (e.g. a trailing newline) instead of crashing on them.
            rows = [fields for fields in (line.split() for line in fp) if fields]

        side = self._IMAGE_SIDE
        num_features = side * side
        # Zero-initialised so that features omitted by the sparse format read as 0.0.
        pixels = np.zeros((len(rows), num_features), dtype=np.float32)
        targets = []
        for row_idx, fields in enumerate(rows):
            # Labels in the file are 1-based; store them 0-based.
            targets.append(int(fields[0]) - 1)
            for feature in fields[1:]:
                key, _, value = feature.partition(':')
                col = int(key) - 1
                # Reject bad indices explicitly: col == -1 would otherwise
                # silently wrap to the last pixel via numpy negative indexing.
                if not 0 <= col < num_features:
                    raise ValueError(
                        f"feature index {key} out of range 1..{num_features} "
                        f"in {full_path}")
                pixels[row_idx, col] = float(value)

        imgs = pixels.reshape((-1, side, side))
        # Map [-1, 1] (assumed input range) onto [0, 255]. Clip so that values
        # marginally outside the range cannot wrap around in the uint8 cast.
        imgs = np.clip((imgs + 1) / 2 * 255, 0, 255).astype(dtype=np.uint8)
        self.data = imgs
        self.targets = targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])
        # Return an image object (converted to 3-channel RGB) rather than a raw
        # array, so this dataset behaves like the other datasets used with it.
        img = Image.fromarray(img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples loaded."""
        return len(self.data)
Pretrained MNIST model
USPS
DATASET
SVHN
The text was updated successfully, but these errors were encountered: