From 8a1a8f1baf73fbdf96869779092914f65dff3b8b Mon Sep 17 00:00:00 2001
From: zhanglang1860 <42694007+zhanglang1860@users.noreply.github.com>
Date: Sat, 9 Feb 2019 15:17:24 +0800
Subject: [PATCH 1/5] zhangyicheng: change download file
---
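Note: the Logger class added to download.py below tees sys.stdout into a log
file while still writing to the terminal. A minimal usage sketch (the
filename here is illustrative, not taken from the patch):

    import sys

    # Logger as defined in download.py below: write() goes to both sinks
    sys.stdout = Logger('debug.log')
    print('this line appears on the terminal and is appended to debug.log')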
.idea/other.xml | 7 ++
.idea/vcs.xml | 6 ++
Cifar10DataReader.py | 91 ++++++++++++++++++++++++++
config.py | 15 +++--
download.py | 148 +++++++++++++++++++++++++-----------------
trainer.py | 6 +-
visualize_training.py | 2 +-
7 files changed, 206 insertions(+), 69 deletions(-)
create mode 100644 .idea/other.xml
create mode 100644 .idea/vcs.xml
create mode 100644 Cifar10DataReader.py
diff --git a/.idea/other.xml b/.idea/other.xml
new file mode 100644
index 0000000..640fd80
--- /dev/null
+++ b/.idea/other.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/Cifar10DataReader.py b/Cifar10DataReader.py
new file mode 100644
index 0000000..d88dbe8
--- /dev/null
+++ b/Cifar10DataReader.py
@@ -0,0 +1,91 @@
+import pickle
+import numpy as np
+import os
+
+
+class Cifar10DataReader():
+ def __init__(self, cifar_folder, onehot=True):
+ self.cifar_folder = cifar_folder
+ self.onehot = onehot
+ self.data_index = 1
+ self.read_next = True
+ self.data_label_train = None
+ self.data_label_test = None
+ self.batch_index = 0
+
+ def unpickle(self, f):
+ fo = open(f, 'rb')
+ d = pickle.load(fo, encoding='bytes')
+ fo.close()
+ return d
+
+ def next_train_data(self, batch_size=100):
+ assert 10000 % batch_size == 0, "10000%batch_size!=0"
+ rdata = None
+ rlabel = None
+ if self.read_next:
+ f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
+ print('read: %s' % f)
+ dic_train = self.unpickle(f)
+ self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels'])) # label 0~9
+ np.random.shuffle(self.data_label_train)
+
+ self.read_next = False
+ if self.data_index == 5:
+ self.data_index = 1
+ else:
+ self.data_index += 1
+
+ if self.batch_index < len(self.data_label_train) // batch_size:
+ # print self.batch_index
+ datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
+ self.batch_index += 1
+ rdata, rlabel = self._decode(datum, self.onehot)
+ else:
+ self.batch_index = 0
+ self.read_next = True
+ return self.next_train_data(batch_size=batch_size)
+
+ return rdata, rlabel
+
+ def _decode(self, datum, onehot):
+ rdata = list()
+ rlabel = list()
+ if onehot:
+ for d, l in datum:
+ rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))  # channel-major bytes -> 32x32x3 HWC
+ hot = np.zeros(10)
+ hot[int(l)] = 1
+ rlabel.append(hot)
+ else:
+ for d, l in datum:
+ rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
+ rlabel.append(int(l))
+ return rdata, rlabel
+
+ def next_test_data(self, batch_size=100):
+ if self.data_label_test is None:
+ f = os.path.join(self.cifar_folder, "test_batch")
+ print('read: %s' % f)
+ dic_test = self.unpickle(f)
+ data = dic_test[b'data']
+ labels = dic_test[b'labels'] # 0~9
+ self.data_label_test = list(zip(data, labels))
+
+ np.random.shuffle(self.data_label_test)
+ datum = self.data_label_test[0:batch_size]
+
+ return self._decode(datum, self.onehot)
+
+
+if __name__ == "__main__":
+ dr = Cifar10DataReader(cifar_folder=r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\cifar-10-batches-py")
+
+ import matplotlib.pyplot as plt
+ d, l = dr.next_test_data()
+ print(np.shape(d), np.shape(l))
+ plt.imshow(d[0])
+ plt.show()
+ for i in range(600):
+ d, l = dr.next_train_data(batch_size=100)
+ print(np.shape(d), np.shape(l))
diff --git a/config.py b/config.py
index fd87b3c..234d99a 100644
--- a/config.py
+++ b/config.py
@@ -6,9 +6,9 @@
def argparser(is_train=True):
-
def str2bool(v):
return v.lower() == 'true'
+
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -16,7 +16,7 @@ def str2bool(v):
parser.add_argument('--prefix', type=str, default='default')
parser.add_argument('--train_dir', type=str)
parser.add_argument('--checkpoint', type=str, default=None)
- parser.add_argument('--dataset', type=str, default='CIFAR10',
+ parser.add_argument('--dataset', type=str, default='CIFAR10',
choices=['MNIST', 'SVHN', 'CIFAR10'])
parser.add_argument('--dump_result', type=str2bool, default=False)
# Model
@@ -36,7 +36,7 @@ def str2bool(v):
parser.add_argument('--test_sample_step', type=int, default=100)
parser.add_argument('--output_save_step', type=int, default=1000)
# learning
- parser.add_argument('--max_sample', type=int, default=5000,
+ parser.add_argument('--max_sample', type=int, default=5000,
help='num of samples the model can see')
parser.add_argument('--max_training_steps', type=int, default=10000000)
parser.add_argument('--learning_rate_g', type=float, default=1e-4)
@@ -51,16 +51,17 @@ def str2bool(v):
config = parser.parse_args()
- dataset_path = os.path.join('./datasets', config.dataset.lower())
+ dataset_path = os.path.join(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets",
+ config.dataset.lower())
dataset_train, dataset_test = dataset.create_default_splits(dataset_path)
-
+ print("step2")
img, label = dataset_train.get_data(dataset_train.ids[0])
+ print("step3")
config.h = img.shape[0]
config.w = img.shape[1]
config.c = img.shape[2]
- config.num_class = label.shape[0]
+ config.num_class = label.shape[0]
# --- create model ---
model = Model(config, debug_information=config.debug, is_train=is_train)
-
return config, model, dataset_train, dataset_test
diff --git a/download.py b/download.py
index 90c702e..e8ac1e8 100644
--- a/download.py
+++ b/download.py
@@ -11,7 +11,6 @@
def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, shape=None):
-
image = np.concatenate((train_image, test_image), axis=0).astype(np.uint8)
label = np.concatenate((train_label, test_label), axis=0).astype(np.uint8)
@@ -24,35 +23,37 @@ def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, sha
bar.start()
f = h5py.File(os.path.join(data_dir, 'data.hdf5'), 'w')
- data_id = open(os.path.join(data_dir,'id.txt'), 'w')
+ data_id = open(os.path.join(data_dir, 'id.txt'), 'w')
for i in range(image.shape[0]):
- if i%(image.shape[0]/100)==0:
- bar.update(i/(image.shape[0]/100))
+ if i % (image.shape[0] / 100) == 0:
+ bar.update(i / (image.shape[0] / 100))
grp = f.create_group(str(i))
- data_id.write(str(i)+'\n')
+ data_id.write(str(i) + '\n')
if shape:
grp['image'] = np.reshape(image[i], shape, order='F')
else:
grp['image'] = image[i]
label_vec = np.zeros(10)
- label_vec[label[i]%10] = 1
+ label_vec[label[i] % 10] = 1
grp['label'] = label_vec.astype(np.bool)
bar.finish()
f.close()
data_id.close()
return
+
def check_file(data_dir):
if os.path.exists(data_dir):
if os.path.isfile(os.path.join('data.hdf5')) and \
- os.path.isfile(os.path.join('id.txt')):
+ os.path.isfile(os.path.join('id.txt')):
return True
else:
os.mkdir(data_dir)
return False
+
def download_mnist(download_path):
data_dir = os.path.join(download_path, 'mnist')
@@ -62,42 +63,46 @@ def download_mnist(download_path):
data_url = 'http://yann.lecun.com/exdb/mnist/'
keys = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
- 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
-
- for k in keys:
- url = (data_url+k).format(**locals())
- target_path = os.path.join(data_dir, k)
- cmd = ['curl', url, '-o', target_path]
- print('Downloading ', k)
- subprocess.call(cmd)
- cmd = ['gzip', '-d', target_path]
- print('Unzip ', k)
- subprocess.call(cmd)
-
+ 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
+ # keys = [ 'train-labels-idx1-ubyte.gz']
+
+ # for k in keys:
+ # url = (data_url+k).format(**locals())
+ # target_path = os.path.join(data_dir, k)
+ # print(target_path)
+ # cmd = ['curl', url, '-o', target_path]
+ # print('Downloading ', k)
+ # subprocess.call(cmd)
+ # # cmd = ['gzip', '-d', target_path]
+ # cmd = ['7z', 'e', target_path]
+ # print('Unzip ', k)
+ # # subprocess.call(cmd)
+ # print('OK ', k)
num_mnist_train = 60000
num_mnist_test = 10000
- fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
- train_image = loaded[16:].reshape((num_mnist_train,28,28,1)).astype(np.float)
+ fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
+ train_image = loaded[16:].reshape((num_mnist_train, 28, 28, 1)).astype(np.float)
- fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
+ fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
train_label = np.asarray(loaded[8:].reshape((num_mnist_train)).astype(np.float))
- fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
- test_image = loaded[16:].reshape((num_mnist_test,28,28,1)).astype(np.float)
+ fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
+ test_image = loaded[16:].reshape((num_mnist_test, 28, 28, 1)).astype(np.float)
- fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
+ fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
test_label = np.asarray(loaded[8:].reshape((num_mnist_test)).astype(np.float))
prepare_h5py(train_image, train_label, test_image, test_label, data_dir)
- for k in keys:
- cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])]
- subprocess.call(cmd)
+ # for k in keys:
+ # cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])]
+ # subprocess.call(cmd)
+
def download_svhn(download_path):
data_dir = os.path.join(download_path, 'svhn')
@@ -123,31 +128,32 @@ def svhn_loader(url, path):
prepare_h5py(np.transpose(train_image, (3, 0, 1, 2)), train_label,
np.transpose(test_image, (3, 0, 1, 2)), test_label, data_dir)
- cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')]
- subprocess.call(cmd)
+ # cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')]
+ # subprocess.call(cmd)
+
def download_cifar10(download_path):
data_dir = os.path.join(download_path, 'cifar10')
# cifar file loader
def unpickle(file):
- import cPickle
+ import pickle
with open(file, 'rb') as fo:
- dict = cPickle.load(fo)
+ dict = pickle.load(fo, encoding='bytes')
return dict
if check_file(data_dir):
print('CIFAR was downloaded.')
return
- data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
- k = 'cifar-10-python.tar.gz'
- target_path = os.path.join(data_dir, k)
- print(target_path)
- cmd = ['curl', data_url, '-o', target_path]
- print('Downloading CIFAR10')
- subprocess.call(cmd)
- tarfile.open(target_path, 'r:gz').extractall(data_dir)
+ # data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+ # k = 'cifar-10-python.tar.gz'
+ # target_path = os.path.join(data_dir, k)
+ # print(target_path)
+ # cmd = ['curl', data_url, '-o', target_path]
+ # print('Downloading CIFAR10')
+ # subprocess.call(cmd)
+ # tarfile.open(target_path, 'r:gz').extractall(data_dir)
num_cifar_train = 50000
num_cifar_test = 10000
@@ -156,34 +162,60 @@ def unpickle(file):
train_image = []
train_label = []
for i in range(5):
- fd = os.path.join(target_path, 'data_batch_'+str(i+1))
+ fd = os.path.join(target_path, 'data_batch_' + str(i + 1))
dict = unpickle(fd)
- train_image.append(dict['data'])
- train_label.append(dict['labels'])
- train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32*32*3])
+ sys.stdout = Logger(r'C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\a.txt')
+ print(dict)
+ print('------------------')
+
+ train_image.append(dict[b'data'])
+ train_label.append(dict[b'labels'])
+
+
+ train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32 * 32 * 3])
train_label = np.reshape(np.array(np.stack(train_label, axis=0)), [num_cifar_train])
fd = os.path.join(target_path, 'test_batch')
dict = unpickle(fd)
- test_image = np.reshape(dict['data'], [num_cifar_test, 32*32*3])
- test_label = np.reshape(dict['labels'], [num_cifar_test])
+ test_image = np.reshape(dict[b'data'], [num_cifar_test, 32 * 32 * 3])
+ test_label = np.reshape(dict[b'labels'], [num_cifar_test])
prepare_h5py(train_image, train_label, test_image, test_label, data_dir, [32, 32, 3])
- cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')]
- subprocess.call(cmd)
- cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')]
- subprocess.call(cmd)
+ # cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')]
+ # subprocess.call(cmd)
+ # cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')]
+ # subprocess.call(cmd)
+
+
+import sys
+import os
+
+
+class Logger(object):
+ def __init__(self, filename="Default.log"):
+ self.terminal = sys.stdout
+ self.log = open(filename, "a")
+
+ def write(self, message):
+ self.terminal.write(message)
+ self.log.write(message)
+
+ def flush(self):
+ pass
+
+
+
if __name__ == '__main__':
args = parser.parse_args()
- path = './datasets'
+ path = r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets"
if not os.path.exists(path): os.mkdir(path)
if 'MNIST' in args.datasets:
- download_mnist('./datasets')
+ download_mnist(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets")
if 'SVHN' in args.datasets:
- download_svhn('./datasets')
+ download_svhn(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets")
if 'CIFAR10' in args.datasets:
- download_cifar10('./datasets')
+ download_cifar10(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets")
diff --git a/trainer.py b/trainer.py
index 3186400..779bdd1 100644
--- a/trainer.py
+++ b/trainer.py
@@ -33,7 +33,7 @@ def __init__(self, config, model, dataset, dataset_test):
os.makedirs(self.train_dir)
log.infov("Train Dir: %s", self.train_dir)
-
+ print("step2")
# --- input ops ---
self.batch_size = config.batch_size
@@ -205,9 +205,9 @@ def log_step_message(self, step, accuracy, d_loss, g_loss,
)
def main():
-
+ print("step0")
config, model, dataset_train, dataset_test = argparser(is_train=True)
-
+ print("step1")
trainer = Trainer(config, model, dataset_train, dataset_test)
log.warning("dataset: %s, learning_rate_g: %f, learning_rate_d: %f",
diff --git a/visualize_training.py b/visualize_training.py
index 88e0ad8..ced515d 100644
--- a/visualize_training.py
+++ b/visualize_training.py
@@ -27,5 +27,5 @@
II.append(I)
II = np.stack(II)
-print II.shape
+print(II.shape)
imageio.mimsave(args.output_file, II, fps=5)
From e1bb513dee59f3b7e242d7bfc49d4e437098054d Mon Sep 17 00:00:00 2001
From: zhangyicheng
Date: Tue, 12 Feb 2019 23:39:34 +0800
Subject: [PATCH 2/5] change path names to Ubuntu
---
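Note: this commit only swaps hard-coded Windows paths for equally hard-coded
Ubuntu paths. A portable sketch of the same configuration, resolving the
dataset directory relative to the source file instead (an alternative, not
what the patch does):

    import os

    # directory containing this script, independent of OS and checkout location
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATASET_DIR = os.path.join(BASE_DIR, 'datasets')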
config.py | 2 +-
download.py | 10 +++++-----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/config.py b/config.py
index 234d99a..a37d35e 100644
--- a/config.py
+++ b/config.py
@@ -51,7 +51,7 @@ def str2bool(v):
config = parser.parse_args()
- dataset_path = os.path.join(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets",
+ dataset_path = os.path.join(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets",
config.dataset.lower())
dataset_train, dataset_test = dataset.create_default_splits(dataset_path)
print("step2")
diff --git a/download.py b/download.py
index e8ac1e8..02b4cfb 100644
--- a/download.py
+++ b/download.py
@@ -165,7 +165,7 @@ def unpickle(file):
fd = os.path.join(target_path, 'data_batch_' + str(i + 1))
dict = unpickle(fd)
- sys.stdout = Logger(r'C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\a.txt')
+ sys.stdout = Logger(r'/home/zhou/Project/tf/SSGAN-Tensorflow/a.txt')
print(dict)
print('------------------')
@@ -210,12 +210,12 @@ def flush(self):
if __name__ == '__main__':
args = parser.parse_args()
- path = r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets"
+ path = r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets"
if not os.path.exists(path): os.mkdir(path)
if 'MNIST' in args.datasets:
- download_mnist(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets")
+ download_mnist(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets")
if 'SVHN' in args.datasets:
- download_svhn(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets")
+ download_svhn(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets")
if 'CIFAR10' in args.datasets:
- download_cifar10(r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets")
+ download_cifar10(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets")
From d8dd4a585faffe8cffe18e8dfb38b164b9cbe642 Mon Sep 17 00:00:00 2001
From: zhangyicheng1986
Date: Sun, 24 Feb 2019 19:16:44 +0800
Subject: [PATCH 3/5] init aaa
---
.idea/SSGAN-Tensorflow.iml | 12 ++++++++++++
.idea/misc.xml | 4 ++++
.idea/modules.xml | 8 ++++++++
aaaa.py | 1 +
4 files changed, 25 insertions(+)
create mode 100644 .idea/SSGAN-Tensorflow.iml
create mode 100644 .idea/misc.xml
create mode 100644 .idea/modules.xml
create mode 100644 aaaa.py
diff --git a/.idea/SSGAN-Tensorflow.iml b/.idea/SSGAN-Tensorflow.iml
new file mode 100644
index 0000000..6f63a63
--- /dev/null
+++ b/.idea/SSGAN-Tensorflow.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..12f37ed
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..0dd1cf6
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/SSGAN-Tensorflow.iml" filepath="$PROJECT_DIR$/.idea/SSGAN-Tensorflow.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/aaaa.py b/aaaa.py
new file mode 100644
index 0000000..086876d
--- /dev/null
+++ b/aaaa.py
@@ -0,0 +1 @@
+hggk
\ No newline at end of file
From f5f71fe9f893ddedc98c2e911d11bb1d6187cb6a Mon Sep 17 00:00:00 2001
From: zhangyicheng1986
Date: Sun, 24 Feb 2019 19:18:37 +0800
Subject: [PATCH 4/5] remove aaa
---
aaaa.py | 1 -
1 file changed, 1 deletion(-)
delete mode 100644 aaaa.py
diff --git a/aaaa.py b/aaaa.py
deleted file mode 100644
index 086876d..0000000
--- a/aaaa.py
+++ /dev/null
@@ -1 +0,0 @@
-hggk
\ No newline at end of file
From 5fa3faa1328ae8d9e0fc503e5b8e1dd9f7fac2e3 Mon Sep 17 00:00:00 2001
From: zhangyicheng1986
Date: Tue, 26 Feb 2019 00:42:06 +0800
Subject: [PATCH 5/5] for Ubuntu 16.04 and Python 2.7
---
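Note: this commit pins unpickle() back to cPickle because pickle.load()
under Python 2.7 accepts no encoding argument. A sketch that would keep the
CIFAR-10 loader working under either interpreter (assuming callers accept
the bytes-keyed dict that the Python 3 branch produces):

    import sys

    if sys.version_info[0] >= 3:
        import pickle

        def unpickle(path):
            # CIFAR-10 batches are Python 2 pickles; decode keys/values as raw bytes
            with open(path, 'rb') as fo:
                return pickle.load(fo, encoding='bytes')
    else:
        import cPickle

        def unpickle(path):
            with open(path, 'rb') as fo:
                return cPickle.load(fo)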
.idea/SSGAN-Tensorflow.iml | 2 +-
.idea/misc.xml | 2 +-
Cifar10DataReader.py | 2 +-
config.py | 2 +-
download.py | 38 +++++++++++++++++++-------------------
5 files changed, 23 insertions(+), 23 deletions(-)
diff --git a/.idea/SSGAN-Tensorflow.iml b/.idea/SSGAN-Tensorflow.iml
index 6f63a63..f3dcd9e 100644
--- a/.idea/SSGAN-Tensorflow.iml
+++ b/.idea/SSGAN-Tensorflow.iml
@@ -2,7 +2,7 @@
-
+
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 12f37ed..1b55951 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/Cifar10DataReader.py b/Cifar10DataReader.py
index d88dbe8..0eeaca7 100644
--- a/Cifar10DataReader.py
+++ b/Cifar10DataReader.py
@@ -79,7 +79,7 @@ def next_test_data(self, batch_size=100):
if __name__ == "__main__":
- dr = Cifar10DataReader(cifar_folder=r"C:\Users\Administrator\PycharmProjects\SSGAN-Tensorflow\datasets\cifar10\cifar-10-batches-py")
+ dr = Cifar10DataReader(cifar_folder=r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets/cifar10/cifar-10-batches-py")
import matplotlib.pyplot as plt
d, l = dr.next_test_data()
diff --git a/config.py b/config.py
index a37d35e..eb73ebd 100644
--- a/config.py
+++ b/config.py
@@ -51,7 +51,7 @@ def str2bool(v):
config = parser.parse_args()
- dataset_path = os.path.join(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets",
+ dataset_path = os.path.join(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets",
config.dataset.lower())
dataset_train, dataset_test = dataset.create_default_splits(dataset_path)
print("step2")
diff --git a/download.py b/download.py
index 02b4cfb..a969f73 100644
--- a/download.py
+++ b/download.py
@@ -137,23 +137,23 @@ def download_cifar10(download_path):
# cifar file loader
def unpickle(file):
- import pickle
+ import cPickle
with open(file, 'rb') as fo:
- dict = pickle.load(fo, encoding='bytes')
+ dict = cPickle.load(fo)  # Python 2: no encoding argument (was pickle.load(fo, encoding='bytes'))
return dict
if check_file(data_dir):
print('CIFAR was downloaded.')
return
- # data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
- # k = 'cifar-10-python.tar.gz'
- # target_path = os.path.join(data_dir, k)
- # print(target_path)
- # cmd = ['curl', data_url, '-o', target_path]
- # print('Downloading CIFAR10')
- # subprocess.call(cmd)
- # tarfile.open(target_path, 'r:gz').extractall(data_dir)
+ data_url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+ k = 'cifar-10-python.tar.gz'
+ target_path = os.path.join(data_dir, k)
+ print(target_path)
+ #cmd = ['curl', data_url, '-o', target_path]
+ print('Downloading CIFAR10')
+ #subprocess.call(cmd)
+ tarfile.open(target_path, 'r:gz').extractall(data_dir)
num_cifar_train = 50000
num_cifar_test = 10000
@@ -165,7 +165,7 @@ def unpickle(file):
fd = os.path.join(target_path, 'data_batch_' + str(i + 1))
dict = unpickle(fd)
- sys.stdout = Logger(r'/home/zhou/Project/tf/SSGAN-Tensorflow/a.txt')
+ sys.stdout = Logger(r'/home/yc/PycharmProjects/SSGAN-Tensorflow/a.txt')
print(dict)
print('------------------')
@@ -183,10 +183,10 @@ def unpickle(file):
prepare_h5py(train_image, train_label, test_image, test_label, data_dir, [32, 32, 3])
- # cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')]
- # subprocess.call(cmd)
- # cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')]
- # subprocess.call(cmd)
+ cmd = ['rm', '-f', os.path.join(data_dir, 'cifar-10-python.tar.gz')]
+ subprocess.call(cmd)
+ cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')]
+ subprocess.call(cmd)
import sys
@@ -210,12 +210,12 @@ def flush(self):
if __name__ == '__main__':
args = parser.parse_args()
- path = r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets"
+ path = r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets"
if not os.path.exists(path): os.mkdir(path)
if 'MNIST' in args.datasets:
- download_mnist(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets")
+ download_mnist(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
if 'SVHN' in args.datasets:
- download_svhn(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets")
+ download_svhn(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
if 'CIFAR10' in args.datasets:
- download_cifar10(r"/home/zhou/Project/tf/SSGAN-Tensorflow/datasets")
+ download_cifar10(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")