Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generators should be indexed from 1 #1

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ ENV/

# mypy
.mypy_cache/

*.gz
*.t7

out
5 changes: 3 additions & 2 deletions began.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def began_disc(x, name, **kwargs):
figs[i].canvas.draw()

plt.savefig(out_dir + fig_names[i].format(it / 1000), bbox_inches='tight')

if PLT_CLOSE == 1:
plt.close()
# Run evaluation functions
for func in eval_funcs:
func(it, img_generator)
Expand Down Expand Up @@ -178,4 +179,4 @@ def began_disc(x, name, **kwargs):
d_dec = DCGAN_G(n_in=dim_h, last_act=tf.sigmoid, bn=False)

train_began(data, g_net, d_enc, d_dec, name=out_name, dim_z=dim_z, batch_size=args.batchsize, lr=args.lr,
eval_funcs=[lambda it, gen: eval_images_naive(it, gen, data)])
eval_funcs=[lambda it, gen: eval_images_naive(it, gen, data)])
17 changes: 12 additions & 5 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
import argparse
import numpy as np
import io

import matplotlib.cm as cm

# Global configs
PRNT_INTERVAL = 100
EVAL_INTERVAL = 2000
SHOW_FIG_INTERVAL = 1000
SAVE_INTERVAL = 4000

PLT_CLOSE = 1 #TODO set it to 1 for 1D exp as seaborn needs it
DATASETS = ['mnist', 'celeba']


Expand Down Expand Up @@ -45,7 +46,7 @@ def create_dirs(name, g_name, d_name, hyperparams=None):
def check_dataset_type(shape):
assert(shape)

if len(shape) == 1:
if len(shape) == 1 or shape == (1,1):
return 'synthetic'
elif shape[2] == 1:
assert(shape[0] == 28 and shape[1] == 28)
Expand Down Expand Up @@ -93,7 +94,13 @@ def scatter(samples, figId=None, retBytes=False, xlim=None, ylim=None):
fig = plt.figure(figId)
fig.clear()

n_gen = 8 #TODO
colors = cm.rainbow(np.linspace(0, 1, n_gen)) #TODO
colors = np.repeat(colors, len(samples[:,0])/n_gen, 0) #TODO

#plt.scatter(samples[:,0], samples[:,1], c = colors, alpha=0.1) #TODO
plt.scatter(samples[:,0], samples[:,1], alpha=0.1)

if xlim:
plt.xlim(xlim[0], xlim[1])
if ylim:
Expand All @@ -112,7 +119,7 @@ def scatter(samples, figId=None, retBytes=False, xlim=None, ylim=None):
def parse_args(batchsize=128, lr=1e-5, additional_args=[]):
parser = argparse.ArgumentParser()

parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--gpu', type=int, default=0) #TODO
parser.add_argument('--batchsize', type=int, default=batchsize)
parser.add_argument('--datasets', choices=DATASETS, default=DATASETS[0])
parser.add_argument('--lr', type=float, default=lr)
Expand All @@ -129,4 +136,4 @@ def parse_args(batchsize=128, lr=1e-5, additional_args=[]):
def set_gpu(gpu_id):
print "Override GPU setting: gpu={}".format(gpu_id)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_id)
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(gpu_id)
2 changes: 1 addition & 1 deletion datasets/data_celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ def plot(self, img_generator, fig_id=None):

for i in range(16):
cv2.imshow('image', ims[i][:, :, (2,1,0)])
cv2.waitKey(0)
cv2.waitKey(0)
71 changes: 71 additions & 0 deletions datasets/data_generator.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
-- Uses a space-separated specs.txt to output input.txt.
-- specs.txt needs a header row with three (1-D) or four (2-D) columns.

require 'torch'
require 'nn'
require 'optim'
require 'pl'

-- Default options; each entry can be overridden by an environment
-- variable of the same name (see the os.getenv loop below).
opt={
folder='', -- root directory for datasets
data_name='', -- dataset subdirectory under `folder`
num_samples=768000, -- number of points to draw from the mixture
filename='specs.txt', -- input spec file (centres, std devs, densities)
out_name='input.txt', -- text output: one "<label> <coords...>" line per sample
t7_filename='data.t7' -- serialized torch tensor of the same samples
}

-- Count every key/value pair in a table, hash and array parts alike.
-- (#T only measures the sequence part, so we walk all keys with next.)
function tablelength(T)
  local n, key = 0, nil
  repeat
    key = next(T, key)
    if key ~= nil then n = n + 1 end
  until key == nil
  return n
end

-- Allow every option in `opt` to be overridden via an environment variable
-- of the same name; numeric strings are converted, others kept as text.
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)

-- Read the space-separated spec file (penlight's pl.data reader).
-- Column layout: <centre coords...> <std dev> <density>, so a 4-column
-- file describes 2-D centres and a 3-column file describes 1-D centres.
local d = data.read(paths.concat(opt.folder, opt.data_name, opt.filename))
local ndim = 1
local colnames = d:column_names()
if tablelength(colnames) == 4 then
  ndim = 2
elseif tablelength(colnames) == 3 then
  ndim = 1
end
print('ndim '..ndim)

-- One spec row per mixture component.
local ncentres = tablelength(d:column_by_name(colnames[1]))

local centres = torch.Tensor(ncentres, ndim)
local std_dev = torch.Tensor(ncentres)
local densities = torch.Tensor(ncentres)

for i = 1, ncentres do
  for j = 1, ndim do
    centres[i][j] = tonumber(d:column_by_name(colnames[j])[i])
  end
  std_dev[i] = tonumber(d:column_by_name(colnames[ndim + 1])[i])
  densities[i] = tonumber(d:column_by_name(colnames[ndim + 2])[i])
end

paths.mkdir(paths.concat(opt.folder, opt.data_name))

-- Pick a component per sample, weighted by density (torch.multinomial
-- accepts unnormalised weights), then draw an isotropic Gaussian point
-- around that component's centre.
local rand_indices = torch.multinomial(densities, opt.num_samples, true)
-- NOTE: named `samples` (not `data` as before) to avoid shadowing the
-- penlight `data` module used above.
local samples = torch.Tensor(opt.num_samples, ndim)
-- assert() surfaces open failures immediately instead of erroring later;
-- write through the handle so the global default output stream is untouched.
local file = assert(io.open(paths.concat(opt.folder, opt.data_name, opt.out_name), 'w'))
for i = 1, opt.num_samples do
  local k = rand_indices[i]
  local point = torch.Tensor(ndim)
  for j = 1, ndim do
    point[j] = torch.normal(0, std_dev[k]) + centres[k][j]
  end
  samples[i] = point
  -- First column is a constant 0 — presumably a class-label placeholder
  -- expected by the downstream reader; TODO confirm against the consumer.
  if ndim == 1 then
    file:write(string.format('%d %f\n', 0, point[1]))
  elseif ndim == 2 then
    file:write(string.format('%d %f %f\n', 0, point[1], point[2]))
  end
end
file:close()

torch.save(paths.concat(opt.folder, opt.data_name, opt.t7_filename), samples)
Loading