From 8a1a8f1baf73fbdf96869779092914f65dff3b8b Mon Sep 17 00:00:00 2001 From: zhanglang1860 <42694007+zhanglang1860@users.noreply.github.com> Date: Sat, 9 Feb 2019 15:17:24 +0800 Subject: [PATCH 1/5] zhangyicheng change download file --- .idea/other.xml | 7 ++ .idea/vcs.xml | 6 ++ Cifar10DataReader.py | 91 ++++++++++++++++++++++++++ config.py | 15 +++-- download.py | 148 +++++++++++++++++++++++++----------------- trainer.py | 6 +- visualize_training.py | 2 +- 7 files changed, 206 insertions(+), 69 deletions(-) create mode 100644 .idea/other.xml create mode 100644 .idea/vcs.xml create mode 100644 Cifar10DataReader.py diff --git a/.idea/other.xml b/.idea/other.xml new file mode 100644 index 0000000..640fd80 --- /dev/null +++ b/.idea/other.xml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Cifar10DataReader.py b/Cifar10DataReader.py new file mode 100644 index 0000000..d88dbe8 --- /dev/null +++ b/Cifar10DataReader.py @@ -0,0 +1,91 @@ +import pickle +import numpy as np +import os + + +class Cifar10DataReader(): + def __init__(self, cifar_folder, onehot=True): + self.cifar_folder = cifar_folder + self.onehot = onehot + self.data_index = 1 + self.read_next = True + self.data_label_train = None + self.data_label_test = None + self.batch_index = 0 + + def unpickle(self, f): + fo = open(f, 'rb') + d = pickle.load(fo, encoding='bytes') + fo.close() + return d + + def next_train_data(self, batch_size=100): + assert 10000 % batch_size == 0, "10000%batch_size!=0" + rdata = None + rlabel = None + if self.read_next: + f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index)) + print('read: %s' % f) + dic_train = self.unpickle(f) + self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels'])) # label 0~9 + np.random.shuffle(self.data_label_train) + + self.read_next = False + if self.data_index == 5: + self.data_index = 1 + else: + self.data_index += 1 + + if self.batch_index < len(self.data_label_train) // batch_size: + # print self.batch_index + datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size] + self.batch_index += 1 + rdata, rlabel = self._decode(datum, self.onehot) + else: + self.batch_index = 0 + self.read_next = True + return self.next_train_data(batch_size=batch_size) + + return rdata, rlabel + + def _decode(self, datum, onehot): + rdata = list(); + rlabel = list() + if onehot: + for d, l in datum: + rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3])) + hot = np.zeros(10) + hot[int(l)] = 1 + rlabel.append(hot) + else: + for d, l in datum: + rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3])) + rlabel.append(int(l)) + return rdata, rlabel + + def next_test_data(self, batch_size=100): + if self.data_label_test is None: + f = os.path.join(self.cifar_folder, "test_batch") + print('read: %s' % f) + dic_test = self.unpickle(f) + data = dic_test[b'data'] + labels = dic_test[b'labels'] # 0~9 + self.data_label_test = list(zip(data, labels)) + + np.random.shuffle(self.data_label_test) + datum = self.data_label_test[0:batch_size] + + return self._decode(datum, self.onehot) + + +if __name__ == "__main__": + dr = Cifar10DataReader(cifar_folder=r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\cifar-10-batches-py") + + import matplotlib.pyplot as plt + d, l = dr.next_test_data() + print(np.shape(d), np.shape(l)) + plt.imshow(d[0]) + plt.show() + for i in range(600): + d, l = dr.next_train_data(batch_size=100) + print(np.shape(d), np.shape(l)) diff --git a/config.py b/config.py index fd87b3c..234d99a 100644 --- a/config.py +++ b/config.py @@ -6,9 +6,9 @@ def argparser(is_train=True): - def str2bool(v): return v.lower() == 'true' + parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -16,7 +16,7 @@ def str2bool(v): parser.add_argument('--prefix', type=str, default='default') parser.add_argument('--train_dir', type=str) parser.add_argument('--checkpoint', type=str, default=None) - parser.add_argument('--dataset', type=str, default='CIFAR10', + parser.add_argument('--dataset', type=str, default='CIFAR10', choices=['MNIST', 'SVHN', 'CIFAR10']) parser.add_argument('--dump_result', type=str2bool, default=False) # Model @@ -36,7 +36,7 @@ def str2bool(v): parser.add_argument('--test_sample_step', type=int, default=100) parser.add_argument('--output_save_step', type=int, default=1000) # learning - parser.add_argument('--max_sample', type=int, default=5000, + parser.add_argument('--max_sample', type=int, default=5000, help='num of samples the model can see') parser.add_argument('--max_training_steps', type=int, default=10000000) parser.add_argument('--learning_rate_g', type=float, default=1e-4) @@ -51,16 +51,17 @@ def str2bool(v): config = parser.parse_args() - dataset_path = os.path.join('./datasets', config.dataset.lower()) + dataset_path = os.path.join(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets", + config.dataset.lower()) dataset_train, dataset_test = dataset.create_default_splits(dataset_path) - + print("step2") img, label = dataset_train.get_data(dataset_train.ids[0]) + print("step3") config.h = img.shape[0] config.w = img.shape[1] config.c = img.shape[2] - config.num_class = label.shape[0] + config.num_class = label.shape[0] # --- create model --- model = Model(config, debug_information=config.debug, is_train=is_train) - return config, model, dataset_train, dataset_test diff --git a/download.py b/download.py index 90c702e..e8ac1e8 100644 --- a/download.py +++ b/download.py @@ -11,7 +11,6 @@ def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, shape=None): - image = np.concatenate((train_image, test_image), axis=0).astype(np.uint8) label = np.concatenate((train_label, test_label), axis=0).astype(np.uint8) @@ -24,35 +23,37 @@ def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, sha bar.start() f = h5py.File(os.path.join(data_dir, 'data.hdf5'), 'w') - data_id = open(os.path.join(data_dir,'id.txt'), 'w') + data_id = open(os.path.join(data_dir, 'id.txt'), 'w') for i in range(image.shape[0]): - if i%(image.shape[0]/100)==0: - bar.update(i/(image.shape[0]/100)) + if i % (image.shape[0] / 100) == 0: + bar.update(i / (image.shape[0] / 100)) grp = f.create_group(str(i)) - data_id.write(str(i)+'\n') + data_id.write(str(i) + '\n') if shape: grp['image'] = np.reshape(image[i], shape, order='F') else: grp['image'] = image[i] label_vec = np.zeros(10) - label_vec[label[i]%10] = 1 + label_vec[label[i] % 10] = 1 grp['label'] = label_vec.astype(np.bool) bar.finish() f.close() data_id.close() return + def check_file(data_dir): if os.path.exists(data_dir): if os.path.isfile(os.path.join('data.hdf5')) and \ - os.path.isfile(os.path.join('id.txt')): + os.path.isfile(os.path.join('id.txt')): return True else: os.mkdir(data_dir) return False + def download_mnist(download_path): data_dir = os.path.join(download_path, 'mnist') @@ -62,42 +63,46 @@ def download_mnist(download_path): data_url = 'http://yann.lecun.com/exdb/mnist/' keys = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', - 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'] - - for k in keys: - url = (data_url+k).format(**locals()) - target_path = os.path.join(data_dir, k) - cmd = ['curl', url, '-o', target_path] - print('Downloading ', k) - subprocess.call(cmd) - cmd = ['gzip', '-d', target_path] - print('Unzip ', k) - subprocess.call(cmd) - + 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'] + # keys = [ 'train-labels-idx1-ubyte.gz'] + + # for k in keys: + # url = (data_url+k).format(**locals()) + # target_path = os.path.join(data_dir, k) + # print(target_path) + # cmd = ['curl', url, '-o', target_path] + # print('Downloading ', k) + # subprocess.call(cmd) + # # cmd = ['gzip', '-d', target_path] + # cmd = ['7z', 'e', target_path] + # print('Unzip ', k) + # # subprocess.call(cmd) + # print('OK ', k) num_mnist_train = 60000 num_mnist_test = 10000 - fd = open(os.path.join(data_dir,'train-images-idx3-ubyte')) - loaded = np.fromfile(file=fd,dtype=np.uint8) - train_image = loaded[16:].reshape((num_mnist_train,28,28,1)).astype(np.float) + fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte')) + loaded = np.fromfile(file=fd, dtype=np.uint8) + train_image = loaded[16:].reshape((num_mnist_train, 28, 28, 1)).astype(np.float) - fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte')) - loaded = np.fromfile(file=fd,dtype=np.uint8) + fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte')) + loaded = np.fromfile(file=fd, dtype=np.uint8) train_label = np.asarray(loaded[8:].reshape((num_mnist_train)).astype(np.float)) - fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte')) - loaded = np.fromfile(file=fd,dtype=np.uint8) - test_image = loaded[16:].reshape((num_mnist_test,28,28,1)).astype(np.float) + fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte')) + loaded = np.fromfile(file=fd, dtype=np.uint8) + test_image = loaded[16:].reshape((num_mnist_test, 28, 28, 1)).astype(np.float) - fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte')) - loaded = np.fromfile(file=fd,dtype=np.uint8) + fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte')) + loaded = np.fromfile(file=fd, dtype=np.uint8) test_label = np.asarray(loaded[8:].reshape((num_mnist_test)).astype(np.float)) prepare_h5py(train_image, train_label, test_image, test_label, data_dir) - for k in keys: - cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])] - subprocess.call(cmd) + # for k in keys: + # cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])] + # subprocess.call(cmd) + def download_svhn(download_path): data_dir = os.path.join(download_path, 'svhn') @@ -123,31 +128,32 @@ def svhn_loader(url, path): prepare_h5py(np.transpose(train_image, (3, 0, 1, 2)), train_label, np.transpose(test_image, (3, 0, 1, 2)), test_label, data_dir) - cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')] - subprocess.call(cmd) + # cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')] + # subprocess.call(cmd) + def download_cifar10(download_path): data_dir = os.path.join(download_path, 'cifar10') # cifar file loader def unpickle(file): - import cPickle + import pickle with open(file, 'rb') as fo: - dict = cPickle.load(fo) + dict = pickle.load(fo, encoding='bytes') return dict if check_file(data_dir): print('CIFAR was downloaded.') return - data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' - k = 'cifar-10-python.tar.gz' - target_path = os.path.join(data_dir, k) - print(target_path) - cmd = ['curl', data_url, '-o', target_path] - print('Downloading CIFAR10') - subprocess.call(cmd) - tarfile.open(target_path, 'r:gz').extractall(data_dir) + # data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + # k = 'cifar-10-python.tar.gz' + # target_path = os.path.join(data_dir, k) + # print(target_path) + # cmd = ['curl', data_url, '-o', target_path] + # print('Downloading CIFAR10') + # subprocess.call(cmd) + # tarfile.open(target_path, 'r:gz').extractall(data_dir) num_cifar_train = 50000 num_cifar_test = 10000 @@ -156,34 +162,60 @@ def unpickle(file): train_image = [] train_label = [] for i in range(5): - fd = os.path.join(target_path, 'data_batch_'+str(i+1)) + fd = os.path.join(target_path, 'data_batch_' + str(i + 1)) dict = unpickle(fd) - train_image.append(dict['data']) - train_label.append(dict['labels']) - train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32*32*3]) + sys.stdout = Logger(r'C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\a.txt') + print(dict) + print('------------------') + + train_image.append(dict[b'data']) + train_label.append(dict[b'labels']) + + + train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32 * 32 * 3]) train_label = np.reshape(np.array(np.stack(train_label, axis=0)), [num_cifar_train]) fd = os.path.join(target_path, 'test_batch') dict = unpickle(fd) - test_image = np.reshape(dict['data'], [num_cifar_test, 32*32*3]) - test_label = np.reshape(dict['labels'], [num_cifar_test]) + test_image = np.reshape(dict[b'data'], [num_cifar_test, 32 * 32 * 3]) + test_label = np.reshape(dict[b'labels'], [num_cifar_test]) prepare_h5py(train_image, train_label, test_image, test_label, data_dir, [32, 32, 3]) - cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')] - subprocess.call(cmd) - cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')] - subprocess.call(cmd) + # cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')] + # subprocess.call(cmd) + # cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')] + # subprocess.call(cmd) + + +import sys +import os + + +class Logger(object): + def __init__(self, filename="Default.log"): + self.terminal = sys.stdout + self.log = open(filename, "a") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + pass + + + if __name__ == '__main__': args = parser.parse_args() - path = './datasets' + path = r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets" if not os.path.exists(path): os.mkdir(path) if 'MNIST' in args.datasets: - download_mnist('./datasets') + download_mnist(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets") if 'SVHN' in args.datasets: - download_svhn('./datasets') + download_svhn(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets") if 'CIFAR10' in args.datasets: - download_cifar10('./datasets') + download_cifar10(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets") diff --git a/trainer.py b/trainer.py index 3186400..779bdd1 100644 --- a/trainer.py +++ b/trainer.py @@ -33,7 +33,7 @@ def __init__(self, config, model, dataset, dataset_test): os.makedirs(self.train_dir) log.infov("Train Dir: %s", self.train_dir) - + print("step2") # --- input ops --- self.batch_size = config.batch_size @@ -205,9 +205,9 @@ def log_step_message(self, step, accuracy, d_loss, g_loss, ) def main(): - + print("step0") config, model, dataset_train, dataset_test = argparser(is_train=True) - + print("step1") trainer = Trainer(config, model, dataset_train, dataset_test) log.warning("dataset: %s, learning_rate_g: %f, learning_rate_d: %f", diff --git a/visualize_training.py b/visualize_training.py index 88e0ad8..ced515d 100644 --- a/visualize_training.py +++ b/visualize_training.py @@ -27,5 +27,5 @@ II.append(I) II = np.stack(II) -print II.shape +print(II.shape) imageio.mimsave(args.output_file, II, fps=5) From e1bb513dee59f3b7e242d7bfc49d4e437098054d Mon Sep 17 00:00:00 2001 From: zhangyicheng Date: Tue, 12 Feb 2019 23:39:34 +0800 Subject: [PATCH 2/5] path name change to ubantu --- config.py | 2 +- download.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config.py b/config.py index 234d99a..a37d35e 100644 --- a/config.py +++ b/config.py @@ -51,7 +51,7 @@ def str2bool(v): config = parser.parse_args() - dataset_path = os.path.join(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets", + dataset_path = os.path.join(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets", config.dataset.lower()) dataset_train, dataset_test = dataset.create_default_splits(dataset_path) print("step2") diff --git a/download.py b/download.py index e8ac1e8..02b4cfb 100644 --- a/download.py +++ b/download.py @@ -165,7 +165,7 @@ def unpickle(file): fd = os.path.join(target_path, 'data_batch_' + str(i + 1)) dict = unpickle(fd) - sys.stdout = Logger(r'C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\a.txt') + sys.stdout = Logger(r'/home/zhou/Project/tf/SSGAN-Tensorflow/a.txt') print(dict) print('------------------') @@ -210,12 +210,12 @@ def flush(self): if __name__ == '__main__': args = parser.parse_args() - path = r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets" + path = r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets" if not os.path.exists(path): os.mkdir(path) if 'MNIST' in args.datasets: - download_mnist(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets") + download_mnist(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets") if 'SVHN' in args.datasets: - download_svhn(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets") + download_svhn(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets") if 'CIFAR10' in args.datasets: - download_cifar10(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets") + download_cifar10(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets") From d8dd4a585faffe8cffe18e8dfb38b164b9cbe642 Mon Sep 17 00:00:00 2001 From: zhangyicheng1986 Date: Sun, 24 Feb 2019 19:16:44 +0800 Subject: [PATCH 3/5] inint aaa --- .idea/SSGAN-Tensorflow.iml | 12 ++++++++++++ .idea/misc.xml | 4 ++++ .idea/modules.xml | 8 ++++++++ aaaa.py | 1 + 4 files changed, 25 insertions(+) create mode 100644 .idea/SSGAN-Tensorflow.iml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 aaaa.py diff --git a/.idea/SSGAN-Tensorflow.iml b/.idea/SSGAN-Tensorflow.iml new file mode 100644 index 0000000..6f63a63 --- /dev/null +++ b/.idea/SSGAN-Tensorflow.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..12f37ed --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..0dd1cf6 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/aaaa.py b/aaaa.py new file mode 100644 index 0000000..086876d --- /dev/null +++ b/aaaa.py @@ -0,0 +1 @@ +hggk \ No newline at end of file From f5f71fe9f893ddedc98c2e911d11bb1d6187cb6a Mon Sep 17 00:00:00 2001 From: zhangyicheng1986 Date: Sun, 24 Feb 2019 19:18:37 +0800 Subject: [PATCH 4/5] inint aaa --- aaaa.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 aaaa.py diff --git a/aaaa.py b/aaaa.py deleted file mode 100644 index 086876d..0000000 --- a/aaaa.py +++ /dev/null @@ -1 +0,0 @@ -hggk \ No newline at end of file From 5fa3faa1328ae8d9e0fc503e5b8e1dd9f7fac2e3 Mon Sep 17 00:00:00 2001 From: zhangyicheng1986 Date: Tue, 26 Feb 2019 00:42:06 +0800 Subject: [PATCH 5/5] for ubuntu 16.04 and python2.7 --- .idea/SSGAN-Tensorflow.iml | 2 +- .idea/misc.xml | 2 +- Cifar10DataReader.py | 2 +- config.py | 2 +- download.py | 38 +++++++++++++++++++------------------- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/.idea/SSGAN-Tensorflow.iml b/.idea/SSGAN-Tensorflow.iml index 6f63a63..f3dcd9e 100644 --- a/.idea/SSGAN-Tensorflow.iml +++ b/.idea/SSGAN-Tensorflow.iml @@ -2,7 +2,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 12f37ed..1b55951 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/Cifar10DataReader.py b/Cifar10DataReader.py index d88dbe8..0eeaca7 100644 --- a/Cifar10DataReader.py +++ b/Cifar10DataReader.py @@ -79,7 +79,7 @@ def next_test_data(self, batch_size=100): if __name__ == "__main__": - dr = Cifar10DataReader(cifar_folder=r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\cifar-10-batches-py") + dr = Cifar10DataReader(cifar_folder=r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets/cifar10/cifar-10-batches-py") import matplotlib.pyplot as plt d, l = dr.next_test_data() diff --git a/config.py b/config.py index a37d35e..eb73ebd 100644 --- a/config.py +++ b/config.py @@ -51,7 +51,7 @@ def str2bool(v): config = parser.parse_args() - dataset_path = os.path.join(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets", + dataset_path = os.path.join(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets", config.dataset.lower()) dataset_train, dataset_test = dataset.create_default_splits(dataset_path) print("step2") diff --git a/download.py b/download.py index 02b4cfb..a969f73 100644 --- a/download.py +++ b/download.py @@ -137,23 +137,23 @@ def download_cifar10(download_path): # cifar file loader def unpickle(file): - import pickle + import cPickle with open(file, 'rb') as fo: - dict = pickle.load(fo, encoding='bytes') + dict = cPickle.load(fo)#dict = cPickle.load(fo, encoding='bytes') return dict if check_file(data_dir): print('CIFAR was downloaded.') return - # data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' - # k = 'cifar-10-python.tar.gz' - # target_path = os.path.join(data_dir, k) - # print(target_path) - # cmd = ['curl', data_url, '-o', target_path] - # print('Downloading CIFAR10') - # subprocess.call(cmd) - # tarfile.open(target_path, 'r:gz').extractall(data_dir) + data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + k = 'cifar-10-python.tar.gz' + target_path = os.path.join(data_dir, k) + print(target_path) + #cmd = ['curl', data_url, '-o', target_path] + print('Downloading CIFAR10') + #subprocess.call(cmd) + tarfile.open(target_path, 'r:gz').extractall(data_dir) num_cifar_train = 50000 num_cifar_test = 10000 @@ -165,7 +165,7 @@ def unpickle(file): fd = os.path.join(target_path, 'data_batch_' + str(i + 1)) dict = unpickle(fd) - sys.stdout = Logger(r'/home/zhou/Project/tf/SSGAN-Tensorflow/a.txt') + sys.stdout = Logger(r'/home/yc/PycharmProjects/SSGAN-Tensorflow/a.txt') print(dict) print('------------------') @@ -183,10 +183,10 @@ def unpickle(file): prepare_h5py(train_image, train_label, test_image, test_label, data_dir, [32, 32, 3]) - # cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')] - # subprocess.call(cmd) - # cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')] - # subprocess.call(cmd) + cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')] + subprocess.call(cmd) + cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')] + subprocess.call(cmd) import sys @@ -210,12 +210,12 @@ def flush(self): if __name__ == '__main__': args = parser.parse_args() - path = r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets" + path = r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets" if not os.path.exists(path): os.mkdir(path) if 'MNIST' in args.datasets: - download_mnist(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets") + download_mnist(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets") if 'SVHN' in args.datasets: - download_svhn(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets") + download_svhn(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets") if 'CIFAR10' in args.datasets: - download_cifar10(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets") + download_cifar10(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")