diff --git a/Cifar10DataReader.py b/Cifar10DataReader.py
new file mode 100644
index 0000000..0eeaca7
--- /dev/null
+++ b/Cifar10DataReader.py
@@ -0,0 +1,91 @@
+import pickle
+import numpy as np
+import os
+
+
+class Cifar10DataReader():
+    """Mini-batch reader for the python version of the CIFAR-10 archive."""
+
+    def __init__(self, cifar_folder, onehot=True):
+        self.cifar_folder = cifar_folder
+        self.onehot = onehot
+        self.data_index = 1  # next data_batch_<n> file to load (1..5)
+        self.read_next = True
+        self.data_label_train = None
+        self.data_label_test = None
+        self.batch_index = 0
+
+    def unpickle(self, f):
+        with open(f, 'rb') as fo:
+            d = pickle.load(fo, encoding='bytes')
+        return d
+
+    def next_train_data(self, batch_size=100):
+        assert 10000 % batch_size == 0, "batch_size must evenly divide the 10000 images per batch file"
+        rdata = None
+        rlabel = None
+        if self.read_next:
+            f = os.path.join(self.cifar_folder, "data_batch_%s" % self.data_index)
+            print('read: %s' % f)
+            dic_train = self.unpickle(f)
+            self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels']))  # labels 0~9
+            np.random.shuffle(self.data_label_train)
+
+            self.read_next = False
+            # cycle through data_batch_1 .. data_batch_5
+            if self.data_index == 5:
+                self.data_index = 1
+            else:
+                self.data_index += 1
+
+        if self.batch_index < len(self.data_label_train) // batch_size:
+            datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
+            self.batch_index += 1
+            rdata, rlabel = self._decode(datum, self.onehot)
+        else:
+            # current batch file is exhausted: load the next one and retry
+            self.batch_index = 0
+            self.read_next = True
+            return self.next_train_data(batch_size=batch_size)
+
+        return rdata, rlabel
+
+    def _decode(self, datum, onehot):
+        rdata = list()
+        rlabel = list()
+        if onehot:
+            for d, l in datum:
+                # each row is 3072 channel-major bytes; transpose to a 32x32x3 HWC image
+                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
+                hot = np.zeros(10)
+                hot[int(l)] = 1
+                rlabel.append(hot)
+        else:
+            for d, l in datum:
+                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
+                rlabel.append(int(l))
+        return rdata, rlabel
+
+    def next_test_data(self, batch_size=100):
+        if self.data_label_test is None:
+            f = os.path.join(self.cifar_folder, "test_batch")
+            print('read: %s' % f)
+            dic_test = self.unpickle(f)
+            data = dic_test[b'data']
+            labels = dic_test[b'labels']  # 0~9
+            self.data_label_test = list(zip(data, labels))
+
+        # draw a fresh random sample of the test set on every call
+        np.random.shuffle(self.data_label_test)
+        datum = self.data_label_test[0:batch_size]
+
+        return self._decode(datum, self.onehot)
+
+
+if __name__ == "__main__":
+    dr = Cifar10DataReader(cifar_folder=r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets/cifar10/cifar-10-batches-py")
+
+    import matplotlib.pyplot as plt
+
+    d, l = dr.next_test_data()
+    print(np.shape(d), np.shape(l))
+    plt.imshow(d[0])
+    plt.show()
+    for i in range(600):
+        d, l = dr.next_train_data(batch_size=100)
+        print(np.shape(d), np.shape(l))
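
Note: the _decode step in Cifar10DataReader relies on CIFAR-10's channel-major
byte layout. A minimal sketch (not part of the patch) of why
reshape([3, 1024]).T followed by reshape([32, 32, 3]) yields an HWC image:

    import numpy as np

    raw = np.arange(3072, dtype=np.uint8)   # stand-in for one CIFAR-10 data row
    chw = np.reshape(raw, [3, 1024])        # row 0 = red plane, 1 = green, 2 = blue
    hwc = np.reshape(chw.T, [32, 32, 3])    # transpose puts each pixel's channels together
    assert hwc[0, 0, 1] == raw[1024]        # green value of pixel (0, 0)

The equivalent one-liner raw.reshape(3, 32, 32).transpose(1, 2, 0) produces the
same array.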
diff --git a/config.py b/config.py
index fd87b3c..eb73ebd 100644
--- a/config.py
+++ b/config.py
@@ -6,9 +6,9 @@ def argparser(is_train=True):
-
     def str2bool(v): return v.lower() == 'true'
+
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -16,7 +16,7 @@ def str2bool(v):
     parser.add_argument('--prefix', type=str, default='default')
     parser.add_argument('--train_dir', type=str)
     parser.add_argument('--checkpoint', type=str, default=None)
-    parser.add_argument('--dataset', type=str, default='CIFAR10', 
+    parser.add_argument('--dataset', type=str, default='CIFAR10',
                         choices=['MNIST', 'SVHN', 'CIFAR10'])
     parser.add_argument('--dump_result', type=str2bool, default=False)
     # Model
@@ -36,7 +36,7 @@ def str2bool(v):
     parser.add_argument('--test_sample_step', type=int, default=100)
     parser.add_argument('--output_save_step', type=int, default=1000)
     # learning
-    parser.add_argument('--max_sample', type=int, default=5000, 
+    parser.add_argument('--max_sample', type=int, default=5000,
                         help='num of samples the model can see')
     parser.add_argument('--max_training_steps', type=int, default=10000000)
     parser.add_argument('--learning_rate_g', type=float, default=1e-4)
@@ -51,16 +51,17 @@ def str2bool(v):
 
     config = parser.parse_args()
 
-    dataset_path = os.path.join('./datasets', config.dataset.lower())
+    dataset_path = os.path.join(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets",
+                                config.dataset.lower())
 
     dataset_train, dataset_test = dataset.create_default_splits(dataset_path)
-
+    print("step2")
     img, label = dataset_train.get_data(dataset_train.ids[0])
+    print("step3")
     config.h = img.shape[0]
     config.w = img.shape[1]
     config.c = img.shape[2]
-    config.num_class = label.shape[0] 
+    config.num_class = label.shape[0]
 
     # --- create model ---
     model = Model(config, debug_information=config.debug, is_train=is_train)
-
     return config, model, dataset_train, dataset_test
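
Note: the config.py change above pins dataset_path to /home/yc/..., which ties
the script to one machine. A hedged alternative sketch; the SSGAN_DATA_DIR
environment variable is hypothetical, not something this repo defines:

    import os

    # Fall back to the repo-relative ./datasets unless explicitly overridden.
    dataset_root = os.environ.get(
        'SSGAN_DATA_DIR',
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'datasets'))
    dataset_path = os.path.join(dataset_root, config.dataset.lower())

This keeps the original ./datasets behaviour for other machines while still
letting the author point at the PyCharm project directory.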
diff --git a/download.py b/download.py
index 90c702e..a969f73 100644
--- a/download.py
+++ b/download.py
@@ -11,7 +11,6 @@ def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, shape=None):
-
     image = np.concatenate((train_image, test_image), axis=0).astype(np.uint8)
     label = np.concatenate((train_label, test_label), axis=0).astype(np.uint8)
@@ -24,35 +23,37 @@ def prepare_h5py(train_image, train_label, test_image, test_label, data_dir, shape=None):
     bar.start()
 
     f = h5py.File(os.path.join(data_dir, 'data.hdf5'), 'w')
-    data_id = open(os.path.join(data_dir,'id.txt'), 'w')
+    data_id = open(os.path.join(data_dir, 'id.txt'), 'w')
     for i in range(image.shape[0]):
-        if i%(image.shape[0]/100)==0:
-            bar.update(i/(image.shape[0]/100))
+        if i % (image.shape[0] // 100) == 0:  # // keeps these integers under Python 3
+            bar.update(i // (image.shape[0] // 100))
         grp = f.create_group(str(i))
-        data_id.write(str(i)+'\n')
+        data_id.write(str(i) + '\n')
         if shape:
             grp['image'] = np.reshape(image[i], shape, order='F')
         else:
             grp['image'] = image[i]
         label_vec = np.zeros(10)
-        label_vec[label[i]%10] = 1
+        label_vec[label[i] % 10] = 1
         grp['label'] = label_vec.astype(np.bool)
     bar.finish()
     f.close()
     data_id.close()
     return
 
+
 def check_file(data_dir):
     if os.path.exists(data_dir):
-        if os.path.isfile(os.path.join('data.hdf5')) and \
-           os.path.isfile(os.path.join('id.txt')):
+        if os.path.isfile(os.path.join(data_dir, 'data.hdf5')) and \
+                os.path.isfile(os.path.join(data_dir, 'id.txt')):
             return True
     else:
         os.mkdir(data_dir)
     return False
 
+
 def download_mnist(download_path):
     data_dir = os.path.join(download_path, 'mnist')
@@ -62,42 +63,46 @@ def download_mnist(download_path):
     data_url = 'http://yann.lecun.com/exdb/mnist/'
     keys = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
-            't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
-
-    for k in keys:
-        url = (data_url+k).format(**locals())
-        target_path = os.path.join(data_dir, k)
-        cmd = ['curl', url, '-o', target_path]
-        print('Downloading ', k)
-        subprocess.call(cmd)
-        cmd = ['gzip', '-d', target_path]
-        print('Unzip ', k)
-        subprocess.call(cmd)
-
+            't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
+    # keys = ['train-labels-idx1-ubyte.gz']
+
+    # Download/unzip disabled: the archives are fetched and unpacked manually.
+    # for k in keys:
+    #     url = (data_url + k).format(**locals())
+    #     target_path = os.path.join(data_dir, k)
+    #     print(target_path)
+    #     cmd = ['curl', url, '-o', target_path]
+    #     print('Downloading ', k)
+    #     subprocess.call(cmd)
+    #     # cmd = ['gzip', '-d', target_path]
+    #     cmd = ['7z', 'e', target_path]
+    #     print('Unzip ', k)
+    #     # subprocess.call(cmd)
+    #     print('OK ', k)
     num_mnist_train = 60000
     num_mnist_test = 10000
 
-    fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
-    loaded = np.fromfile(file=fd,dtype=np.uint8)
-    train_image = loaded[16:].reshape((num_mnist_train,28,28,1)).astype(np.float)
+    fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'), 'rb')  # binary mode for np.fromfile
+    loaded = np.fromfile(file=fd, dtype=np.uint8)
+    train_image = loaded[16:].reshape((num_mnist_train, 28, 28, 1)).astype(np.float)
 
-    fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
-    loaded = np.fromfile(file=fd,dtype=np.uint8)
+    fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 'rb')
+    loaded = np.fromfile(file=fd, dtype=np.uint8)
     train_label = np.asarray(loaded[8:].reshape((num_mnist_train)).astype(np.float))
 
-    fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
-    loaded = np.fromfile(file=fd,dtype=np.uint8)
-    test_image = loaded[16:].reshape((num_mnist_test,28,28,1)).astype(np.float)
+    fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 'rb')
+    loaded = np.fromfile(file=fd, dtype=np.uint8)
+    test_image = loaded[16:].reshape((num_mnist_test, 28, 28, 1)).astype(np.float)
 
-    fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
-    loaded = np.fromfile(file=fd,dtype=np.uint8)
+    fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 'rb')
+    loaded = np.fromfile(file=fd, dtype=np.uint8)
     test_label = np.asarray(loaded[8:].reshape((num_mnist_test)).astype(np.float))
 
     prepare_h5py(train_image, train_label, test_image, test_label, data_dir)
 
-    for k in keys:
-        cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])]
-        subprocess.call(cmd)
+    # for k in keys:
+    #     cmd = ['rm', '-f', os.path.join(data_dir, k[:-3])]
+    #     subprocess.call(cmd)
+
 
 def download_svhn(download_path):
     data_dir = os.path.join(download_path, 'svhn')
@@ -123,8 +128,9 @@ def svhn_loader(url, path):
     prepare_h5py(np.transpose(train_image, (3, 0, 1, 2)), train_label,
                  np.transpose(test_image, (3, 0, 1, 2)), test_label, data_dir)
 
-    cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')]
-    subprocess.call(cmd)
+    # Cleanup disabled; the downloaded .mat files are kept on disk.
+    # cmd = ['rm', '-f', os.path.join(data_dir, '*.mat')]
+    # subprocess.call(cmd)
+
 
 def download_cifar10(download_path):
     data_dir = os.path.join(download_path, 'cifar10')
@@ -133,7 +139,7 @@ def unpickle(file):
-        import cPickle
+        import pickle  # Python 3: cPickle was merged into pickle
         with open(file, 'rb') as fo:
-            dict = cPickle.load(fo)
+            dict = pickle.load(fo, encoding='bytes')  # keys come back as bytes, hence dict[b'data'] below
         return dict
 
     if check_file(data_dir):
@@ -144,9 +150,9 @@ def unpickle(file):
     k = 'cifar-10-python.tar.gz'
     target_path = os.path.join(data_dir, k)
     print(target_path)
-    cmd = ['curl', data_url, '-o', target_path]
+    # cmd = ['curl', data_url, '-o', target_path]
     print('Downloading CIFAR10')
-    subprocess.call(cmd)
+    # subprocess.call(cmd)
     tarfile.open(target_path, 'r:gz').extractall(data_dir)
 
     num_cifar_train = 50000
@@ -156,18 +162,24 @@ def unpickle(file):
     train_image = []
     train_label = []
     for i in range(5):
-        fd = os.path.join(target_path, 'data_batch_'+str(i+1))
+        fd = os.path.join(target_path, 'data_batch_' + str(i + 1))
         dict = unpickle(fd)
-        train_image.append(dict['data'])
-        train_label.append(dict['labels'])
-    train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32*32*3])
+        # Tee stdout to a log file so the raw batch dicts can be inspected later.
+        sys.stdout = Logger(r'/home/yc/PycharmProjects/SSGAN-Tensorflow/a.txt')
+        print(dict)
+        print('------------------')
+
+        train_image.append(dict[b'data'])
+        train_label.append(dict[b'labels'])
+
+    train_image = np.reshape(np.stack(train_image, axis=0), [num_cifar_train, 32 * 32 * 3])
     train_label = np.reshape(np.array(np.stack(train_label, axis=0)), [num_cifar_train])
 
     fd = os.path.join(target_path, 'test_batch')
     dict = unpickle(fd)
-    test_image = np.reshape(dict['data'], [num_cifar_test, 32*32*3])
-    test_label = np.reshape(dict['labels'], [num_cifar_test])
+    test_image = np.reshape(dict[b'data'], [num_cifar_test, 32 * 32 * 3])
+    test_label = np.reshape(dict[b'labels'], [num_cifar_test])
 
     prepare_h5py(train_image, train_label, test_image, test_label, data_dir, [32, 32, 3])
@@ -176,14 +188,34 @@ def unpickle(file):
     cmd = ['rm', '-rf', os.path.join(data_dir, 'cifar-10-batches-py')]
     subprocess.call(cmd)
 
+
+import sys
+
+
+class Logger(object):
+    """Duplicate everything written to stdout into a log file."""
+
+    def __init__(self, filename="Default.log"):
+        self.terminal = sys.stdout
+        self.log = open(filename, "a")
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        pass
+
+
 if __name__ == '__main__':
     args = parser.parse_args()
-    path = './datasets'
+    path = r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets"
     if not os.path.exists(path):
         os.mkdir(path)
 
     if 'MNIST' in args.datasets:
-        download_mnist('./datasets')
+        download_mnist(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
     if 'SVHN' in args.datasets:
-        download_svhn('./datasets')
+        download_svhn(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
     if 'CIFAR10' in args.datasets:
-        download_cifar10('./datasets')
+        download_cifar10(r"/home/yc/PycharmProjects/SSGAN-Tensorflow/datasets")
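
Note: download_mnist above hard-codes the IDX header offsets (16 bytes for
image files, 8 for label files) and the sample counts (60000/10000). A sketch,
assuming the standard big-endian IDX3 layout, that reads the counts from the
header instead; read_idx_images is an illustrative helper, not part of the repo:

    import struct
    import numpy as np

    def read_idx_images(path):
        with open(path, 'rb') as f:
            # IDX3 header: magic number, image count, rows, cols (big-endian uint32)
            magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
            assert magic == 2051, 'not an IDX3 image file'
            data = np.fromfile(f, dtype=np.uint8)
        return data.reshape(num, rows, cols, 1)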
diff --git a/trainer.py b/trainer.py
index 3186400..779bdd1 100644
--- a/trainer.py
+++ b/trainer.py
@@ -33,7 +33,7 @@ def __init__(self, config, model, dataset, dataset_test):
             os.makedirs(self.train_dir)
         log.infov("Train Dir: %s", self.train_dir)
-
+        print("step2")
         # --- input ops ---
         self.batch_size = config.batch_size
@@ -205,9 +205,9 @@ def log_step_message(self, step, accuracy, d_loss, g_loss,
                )
 
 def main():
-
+    print("step0")
     config, model, dataset_train, dataset_test = argparser(is_train=True)
-
+    print("step1")
     trainer = Trainer(config, model, dataset_train, dataset_test)
 
     log.warning("dataset: %s, learning_rate_g: %f, learning_rate_d: %f",
diff --git a/visualize_training.py b/visualize_training.py
index 88e0ad8..ced515d 100644
--- a/visualize_training.py
+++ b/visualize_training.py
@@ -27,5 +27,5 @@
     II.append(I)
 
 II = np.stack(II)
-print II.shape
+print(II.shape)
 
 imageio.mimsave(args.output_file, II, fps=5)
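
Note: the Logger tee added to download.py leaves flush() as a no-op, so
buffered output may only reach a.txt when the file object is finalized. A
sketch that forwards flushes to both streams (otherwise identical to the
patch's class):

    import sys

    class Logger(object):
        def __init__(self, filename="Default.log"):
            self.terminal = sys.stdout
            self.log = open(filename, "a")

        def write(self, message):
            self.terminal.write(message)
            self.log.write(message)

        def flush(self):
            # Forward flushes so log lines hit disk promptly.
            self.terminal.flush()
            self.log.flush()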