diff --git a/Cifar10DataReader.py b/Cifar10DataReader.py
new file mode 100644
index 0000000..0eeaca7
--- /dev/null
+++ b/Cifar10DataReader.py
@@ -0,0 +1,91 @@
+import pickle
+import numpy as np
+import os
+
+
+class Cifar10DataReader(object):
+    def __init__(self, cifar_folder, onehot=True):
+        self.cifar_folder = cifar_folder
+        self.onehot = onehot
+        self.data_index = 1
+        self.read_next = True
+        self.data_label_train = None
+        self.data_label_test = None
+        self.batch_index = 0
+
+    def unpickle(self, f):
+        # CIFAR-10 batches were pickled by Python 2; decode keys as bytes.
+        with open(f, 'rb') as fo:
+            d = pickle.load(fo, encoding='bytes')
+        return d
+
+    def next_train_data(self, batch_size=100):
+        assert 10000 % batch_size == 0, "batch_size must evenly divide the 10000 images per batch file"
+        rdata = None
+        rlabel = None
+        if self.read_next:
+            f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
+            print('read: %s' % f)
+            dic_train = self.unpickle(f)
+            self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels']))  # label 0~9
+            np.random.shuffle(self.data_label_train)
+
+            self.read_next = False
+            if self.data_index == 5:
+                self.data_index = 1
+            else:
+                self.data_index += 1
+
+        if self.batch_index < len(self.data_label_train) // batch_size:
+            # slice the next batch_size (data, label) pairs from the shuffled list
+            datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
+            self.batch_index += 1
+            rdata, rlabel = self._decode(datum, self.onehot)
+        else:
+            self.batch_index = 0
+            self.read_next = True
+            return self.next_train_data(batch_size=batch_size)
+
+        return rdata, rlabel
+
+    def _decode(self, datum, onehot):
+        rdata, rlabel = [], []
+        for d, l in datum:
+            # Each CIFAR-10 row is 3072 bytes (1024 R, 1024 G, 1024 B);
+            # convert the channel-major layout to a 32x32x3 HWC image.
+            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
+            if onehot:
+                # one-hot encode the class index (labels are 0~9)
+                hot = np.zeros(10)
+                hot[int(l)] = 1
+                rlabel.append(hot)
+            else:
+                rlabel.append(int(l))
+        return rdata, rlabel
+
+    def next_test_data(self, batch_size=100):
+        if self.data_label_test is None:
+            f = os.path.join(self.cifar_folder, "test_batch")
+            print('read: %s' % f)
+            dic_test = self.unpickle(f)
+            data = dic_test[b'data']
+            labels = dic_test[b'labels']  # 0~9
+            self.data_label_test = list(zip(data, labels))
+
+        np.random.shuffle(self.data_label_test)
+        datum = self.data_label_test[0:batch_size]
+
+        return self._decode(datum, self.onehot)
+
+
+if __name__ == "__main__":
+    dr = Cifar10DataReader(cifar_folder=r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets/cifar10/cifar-10-batches-py")
+
+    import matplotlib.pyplot as plt
+    d, l = dr.next_test_data()
+    print(np.shape(d), np.shape(l))
+    plt.imshow(d[0])
+    plt.show()
+    for i in range(600):
+        d, l = dr.next_train_data(batch_size=100)
+        print(np.shape(d), np.shape(l))
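Note on the new Cifar10DataReader: the nested reshape in _decode is a correct
CIFAR-10 row decode (each 3072-byte row stores the 1024-byte R, G, and B planes
back to back). A minimal sketch checking it against the more common
channel-first decode, assuming d is a single uint8 row:

    import numpy as np

    d = np.arange(3072).astype(np.uint8)  # stand-in for one CIFAR-10 row
    img_a = np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3])  # as in _decode
    img_b = d.reshape(3, 32, 32).transpose(1, 2, 0)              # CHW planes -> HWC
    assert (img_a == img_b).all()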
diff --git a/config.py b/config.py
index fd87b3c..eb73ebd 100644
--- a/config.py
+++ b/config.py
@@ -6,9 +6,9 @@
def argparser(is_train=True):
-
def str2bool(v):
return v.lower() == 'true'
+
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -16,7 +16,7 @@ def str2bool(v):
parser.add_argument('--prefix', type=str, default='default')
parser.add_argument('--train_dir', type=str)
parser.add_argument('--checkpoint', type=str, default=None)
- parser.add_argument('--dataset', type=str, default='CIFAR10',
+ parser.add_argument('--dataset', type=str, default='CIFAR10',
choices=['MNIST', 'SVHN', 'CIFAR10'])
parser.add_argument('--dump_result', type=str2bool, default=False)
# Model
@@ -36,7 +36,7 @@ def str2bool(v):
parser.add_argument('--test_sample_step', type=int, default=100)
parser.add_argument('--output_save_step', type=int, default=1000)
# learning
- parser.add_argument('--max_sample', type=int, default=5000,
+ parser.add_argument('--max_sample', type=int, default=5000,
help='num of samples the model can see')
parser.add_argument('--max_training_steps', type=int, default=10000000)
parser.add_argument('--learning_rate_g', type=float, default=1e-4)
@@ -51,16 +51,17 @@ def str2bool(v):
config = parser.parse_args()
- dataset_path = os.path.join('./datasets', config.dataset.lower())
+ dataset_path = os.path.join(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets",
+ config.dataset.lower())
dataset_train, dataset_test = dataset.create_default_splits(dataset_path)
-
+ print("step2")
img, label = dataset_train.get_data(dataset_train.ids[0])
+ print("step3")
config.h = img.shape[0]
config.w = img.shape[1]
config.c = img.shape[2]
- config.num_class = label.shape[0]
+ config.num_class = label.shape[0]
# --- create model ---
model = Model(config, debug_information=config.debug, is_train=is_train)
-
return config, model, dataset_train, dataset_test
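Note on the config.py hunk: hardcoding dataset_path to an absolute home
directory ties the script to one machine. A minimal alternative sketch (the
SSGAN_DATA_ROOT environment variable is an invented name, not part of this
repo) that keeps the original relative default:

    import os

    # Use an environment override when present, else the repo-relative path.
    data_root = os.environ.get('SSGAN_DATA_ROOT', './datasets')
    dataset_path = os.path.join(data_root, config.dataset.lower())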
diff --git a/download.py b/download.py
index 90c702e..a969f73 100644
--- a/download.py
+++ b/download.py
@@ -11,7 +11,6 @@
def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, shape=None):
-
image = np.concatenate((train_image, test_image), axis=0).astype(np.uint8)
label = np.concatenate((train_label, test_label), axis=0).astype(np.uint8)
@@ -24,35 +23,37 @@ def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, sha
bar.start()
f = h5py.File(os.path.join(data_dir, 'data.hdf5'), 'w')
- data_id = open(os.path.join(data_dir,'id.txt'), 'w')
+ data_id = open(os.path.join(data_dir, 'id.txt'), 'w')
for i in range(image.shape[0]):
- if i%(image.shape[0]/100)==0:
- bar.update(i/(image.shape[0]/100))
+        if i % (image.shape[0] // 100) == 0:  # // keeps the step integral on Python 3
+            bar.update(i // (image.shape[0] // 100))
grp = f.create_group(str(i))
- data_id.write(str(i)+'\n')
+ data_id.write(str(i) + '\n')
if shape:
grp['image'] = np.reshape(image[i], shape, order='F')
else:
grp['image'] = image[i]
label_vec = np.zeros(10)
- label_vec[label[i]%10] = 1
+ label_vec[label[i] % 10] = 1
grp['label'] = label_vec.astype(np.bool)
bar.finish()
f.close()
data_id.close()
return
+
def check_file(data_dir):
if os.path.exists(data_dir):
if os.path.isfile(os.path.join('data.hdf5')) and \
- os.path.isfile(os.path.join('id.txt')):
+ os.path.isfile(os.path.join('id.txt')):
return True
else:
os.mkdir(data_dir)
return False
+
def download_mnist(download_path):
data_dir = os.path.join(download_path, 'mnist')
@@ -62,42 +63,46 @@ def download_mnist(download_path):
data_url = 'http://yann.lecun.com/exdb/mnist/'
keys = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
- 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
-
- for k in keys:
- url = (data_url+k).format(**locals())
- target_path = os.path.join(data_dir, k)
- cmd = ['curl', url, '-o', target_path]
- print('Downloading ', k)
- subprocess.call(cmd)
- cmd = ['gzip', '-d', target_path]
- print('Unzip ', k)
- subprocess.call(cmd)
-
+ 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
+ # keys = [ 'train-labels-idx1-ubyte.gz']
+
+ # for k in keys:
+ # url = (data_url+k).format(**locals())
+ # target_path = os.path.join(data_dir, k)
+ # print(target_path)
+ # cmd = ['curl', url, '-o', target_path]
+ # print('Downloading ', k)
+ # subprocess.call(cmd)
+ # # cmd = ['gzip', '-d', target_path]
+ # cmd = ['7z', 'e', target_path]
+ # print('Unzip ', k)
+ # # subprocess.call(cmd)
+ # print('OK ', k)
num_mnist_train = 60000
num_mnist_test = 10000
- fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
- train_image = loaded[16:].reshape((num_mnist_train,28,28,1)).astype(np.float)
+    fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'), 'rb')  # binary mode for np.fromfile
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
+ train_image = loaded[16:].reshape((num_mnist_train, 28, 28, 1)).astype(np.float)
- fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
+    fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
train_label = np.asarray(loaded[8:].reshape((num_mnist_train)).astype(np.float))
- fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
- test_image = loaded[16:].reshape((num_mnist_test,28,28,1)).astype(np.float)
+    fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
+ test_image = loaded[16:].reshape((num_mnist_test, 28, 28, 1)).astype(np.float)
- fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
- loaded = np.fromfile(file=fd,dtype=np.uint8)
+    fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 'rb')
+ loaded = np.fromfile(file=fd, dtype=np.uint8)
test_label = np.asarray(loaded[8:].reshape((num_mnist_test)).astype(np.float))
prepare_h5py(train_image, train_label, test_image, test_label, data_dir)
- for k in keys:
- cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])]
- subprocess.call(cmd)
+ # for k in keys:
+ # cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])]
+ # subprocess.call(cmd)
+
def download_svhn(download_path):
data_dir = os.path.join(download_path, 'svhn')
@@ -123,8 +128,9 @@ def svhn_loader(url, path):
prepare_h5py(np.transpose(train_image, (3, 0, 1, 2)), train_label,
np.transpose(test_image, (3, 0, 1, 2)), test_label, data_dir)
- cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')]
- subprocess.call(cmd)
+ # cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')]
+ # subprocess.call(cmd)
+
def download_cifar10(download_path):
data_dir = os.path.join(download_path, 'cifar10')
@@ -133,7 +139,7 @@ def download_cifar10(download_path):
def unpickle(file):
import cPickle
with open(file, 'rb') as fo:
- dict = cPickle.load(fo)
+            dict = cPickle.load(fo, encoding='bytes')  # Python 3: keys come back as b'data', b'labels'
return dict
if check_file(data_dir):
@@ -144,9 +150,9 @@ def unpickle(file):
k = 'cifar-10-python.tar.gz'
target_path = os.path.join(data_dir, k)
print(target_path)
- cmd = ['curl', data_url, '-o', target_path]
+ #cmd = ['curl', data_url, '-o', target_path]
print('Downloading CIFAR10')
- subprocess.call(cmd)
+ #subprocess.call(cmd)
tarfile.open(target_path, 'r:gz').extractall(data_dir)
num_cifar_train = 50000
@@ -156,18 +162,24 @@ def unpickle(file):
train_image = []
train_label = []
for i in range(5):
- fd = os.path.join(target_path, 'data_batch_'+str(i+1))
+ fd = os.path.join(target_path, 'data_batch_' + str(i + 1))
dict = unpickle(fd)
- train_image.append(dict['data'])
- train_label.append(dict['labels'])
- train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32*32*3])
+ sys.stdout = Logger(r'/home/yc/PycharmProjects/SSGAN-Tensorflow/a.txt')
+ print(dict)
+ print('------------------')
+
+ train_image.append(dict[b'data'])
+ train_label.append(dict[b'labels'])
+
+
+ train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32 * 32 * 3])
train_label = np.reshape(np.array(np.stack(train_label, axis=0)), [num_cifar_train])
fd = os.path.join(target_path, 'test_batch')
dict = unpickle(fd)
- test_image = np.reshape(dict['data'], [num_cifar_test, 32*32*3])
- test_label = np.reshape(dict['labels'], [num_cifar_test])
+ test_image = np.reshape(dict[b'data'], [num_cifar_test, 32 * 32 * 3])
+ test_label = np.reshape(dict[b'labels'], [num_cifar_test])
prepare_h5py(train_image, train_label, test_image, test_label, data_dir, [32, 32, 3])
@@ -176,14 +188,34 @@ def unpickle(file):
cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')]
subprocess.call(cmd)
+
+import sys
+import os
+
+
+class Logger(object):
+    # Tee: mirror everything written to stdout into a log file as well.
+    def __init__(self, filename="Default.log"):
+        self.terminal = sys.stdout
+        self.log = open(filename, "a")
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        self.terminal.flush()  # a no-op flush here can silently drop buffered log output
+        self.log.flush()
+
+
if __name__ == '__main__':
args = parser.parse_args()
- path = './datasets'
+ path = r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets"
if not os.path.exists(path): os.mkdir(path)
if 'MNIST' in args.datasets:
- download_mnist('./datasets')
+ download_mnist(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
if 'SVHN' in args.datasets:
- download_svhn('./datasets')
+ download_svhn(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
if 'CIFAR10' in args.datasets:
- download_cifar10('./datasets')
+ download_cifar10(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
diff --git a/trainer.py b/trainer.py
index 3186400..779bdd1 100644
--- a/trainer.py
+++ b/trainer.py
@@ -33,7 +33,7 @@ def __init__(self, config, model, dataset, dataset_test):
os.makedirs(self.train_dir)
log.infov("Train Dir: %s", self.train_dir)
-
+ print("step2")
# --- input ops ---
self.batch_size = config.batch_size
@@ -205,9 +205,9 @@ def log_step_message(self, step, accuracy, d_loss, g_loss,
)
def main():
-
+ print("step0")
config, model, dataset_train, dataset_test = argparser(is_train=True)
-
+ print("step1")
trainer = Trainer(config, model, dataset_train, dataset_test)
log.warning("dataset: %s, learning_rate_g: %f, learning_rate_d: %f",
diff --git a/visualize_training.py b/visualize_training.py
index 88e0ad8..ced515d 100644
--- a/visualize_training.py
+++ b/visualize_training.py
@@ -27,5 +27,5 @@
II.append(I)
II = np.stack(II)
-print II.shape
+print(II.shape)
imageio.mimsave(args.output_file, II, fps=5)