Update combine_data script for OpenCV 3.0 #29

Open · wants to merge 18 commits into `master`
2 changes: 2 additions & 0 deletions README.md
@@ -130,6 +130,8 @@ th -ldisplay.start 8000 0.0.0.0
```
Then open `http://(hostname):(port)/` in your browser to load the remote desktop.

+L1 error is plotted to the display by default. Set the environment variable `display_plot` to a comma-separated list of the values `errL1`, `errG`, and `errD` to visualize the L1, generator, and discriminator error respectively. For example, to plot only the generator and discriminator errors to the display instead of the default L1 error, set `display_plot="errG,errD"`.

## Citation
If you use this code for your research, please cite our paper <a href="https://arxiv.org/pdf/1611.07004v1.pdf">Image-to-Image Translation Using Conditional Adversarial Networks</a>:

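For reference, a launch line exercising the new option might look like the following (a sketch only; the dataset path and experiment name are hypothetical, and options are passed as environment variables as elsewhere in this README):

```
DATA_ROOT=./datasets/facades name=facades_experiment display_plot="errG,errD" th train.lua
```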
4 changes: 2 additions & 2 deletions data/combine_A_and_B.py
@@ -43,8 +43,8 @@
if args.use_AB:
name_AB = name_AB.replace('_A.', '.') # remove _A
path_AB = os.path.join(img_fold_AB, name_AB)
-im_A = cv2.imread(path_A, cv2.CV_LOAD_IMAGE_COLOR)
-im_B = cv2.imread(path_B, cv2.CV_LOAD_IMAGE_COLOR)
+im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
+im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
im_AB = np.concatenate([im_A, im_B], 1)
cv2.imwrite(path_AB, im_AB)

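The constant rename above is the OpenCV 3.0 API change this PR targets: `cv2.CV_LOAD_IMAGE_COLOR` was removed in OpenCV 3.x in favor of `cv2.IMREAD_COLOR`. If the script needed to keep supporting OpenCV 2.x as well, a version-agnostic fallback is possible (a minimal sketch, not part of this PR; the file name is hypothetical):

```
import cv2

# Prefer the OpenCV 3.x name, fall back to the 2.x name; both equal 1.
IMREAD_COLOR = getattr(cv2, 'IMREAD_COLOR', None)
if IMREAD_COLOR is None:
    IMREAD_COLOR = getattr(cv2, 'CV_LOAD_IMAGE_COLOR', 1)

im = cv2.imread('example.jpg', IMREAD_COLOR)
```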
22 changes: 11 additions & 11 deletions data/donkey_folder.lua
@@ -33,12 +33,12 @@ local trainCache = paths.concat(cache, cache_prefix .. '_trainCache.t7')
--------------------------------------------------------------------------------------------
local input_nc = opt.input_nc -- input channels
local output_nc = opt.output_nc
-local loadSize = {input_nc, opt.loadSize}
-local sampleSize = {input_nc, opt.fineSize}
+local loadSize = {input_nc, opt.imgWidth, opt.imgHeight}
+local sampleSize = {input_nc, opt.imgWidth, opt.imgHeight}

local preprocessAandB = function(imA, imB)
-imA = image.scale(imA, loadSize[2], loadSize[2])
-imB = image.scale(imB, loadSize[2], loadSize[2])
+imA = image.scale(imA, loadSize[2], loadSize[3])
+imB = image.scale(imB, loadSize[2], loadSize[3])
local perm = torch.LongTensor{3, 2, 1}
imA = imA:index(1, perm)--:mul(256.0): brg, rgb
imA = imA:mul(2):add(-1)
@@ -52,7 +52,7 @@ local preprocessAandB = function(imA, imB)


local oW = sampleSize[2]
-local oH = sampleSize[2]
+local oH = sampleSize[3]
local iH = imA:size(2)
local iW = imA:size(3)

@@ -80,10 +80,10 @@ end

local function loadImageChannel(path)
local input = image.load(path, 3, 'float')
-input = image.scale(input, loadSize[2], loadSize[2])
+input = image.scale(input, loadSize[2], loadSize[3])

local oW = sampleSize[2]
-local oH = sampleSize[2]
+local oH = sampleSize[3]
local iH = input:size(2)
local iW = input:size(3)

@@ -161,7 +161,7 @@ print('trainCache', trainCache)
-- trainLoader = torch.load(trainCache)
-- trainLoader.sampleHookTrain = trainHook
-- trainLoader.loadSize = {input_nc, opt.loadSize, opt.loadSize}
--- trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]}
+-- trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[3]}
-- trainLoader.serial_batches = opt.serial_batches
-- trainLoader.split = 100
--else
@@ -170,8 +170,8 @@ print('Creating train metadata')
print('serial batch:, ', opt.serial_batches)
trainLoader = dataLoader{
paths = {opt.data},
-loadSize = {input_nc, loadSize[2], loadSize[2]},
-sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]},
+loadSize = {input_nc, loadSize[2], loadSize[3]},
+sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[3]},
split = 100,
serial_batches = opt.serial_batches,
verbose = true
@@ -189,4 +189,4 @@ do
local nClasses = #trainLoader.classes
assert(class:max() <= nClasses, "class logic has error")
assert(class:min() >= 1, "class logic has error")
-end
+end
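The change from one square `loadSize`/`fineSize` to separate `imgWidth`/`imgHeight` entries is what enables non-square samples. A minimal sketch of the scaling convention relied on above, assuming Torch's `image` package (`image.scale` takes width then height, while tensors are laid out channels x height x width):

```
require 'image'

local img = torch.rand(3, 400, 300)        -- channels x height x width
local scaled = image.scale(img, 256, 512)  -- image.scale(src, width, height)
print(scaled:size())                       -- 3 x 512 x 256
```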
21 changes: 11 additions & 10 deletions test.lua
@@ -13,7 +13,8 @@ opt = {
DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
batchSize = 1, -- # images in batch
loadSize = 256, -- scale images to this size
-fineSize = 256, -- then crop to this size
+imgWidth = 256, -- then crop to this size. Both should be multiples of 32...
+imgHeight = 512,
flip=0, -- horizontal mirroring data augmentation
display = 1, -- display samples while training. 0 = false
display_id = 200, -- display window id.
@@ -69,8 +70,8 @@ else
end
----------------------------------------------------------------------------

-local input = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
-local target = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+local input = torch.FloatTensor(opt.batchSize,3,opt.imgHeight,opt.imgWidth)
+local target = torch.FloatTensor(opt.batchSize,3,opt.imgHeight,opt.imgWidth)

print('checkpoints_dir', opt.checkpoints_dir)
local netG = util.load(paths.concat(opt.checkpoints_dir, opt.netG_name .. '.t7'), opt)
@@ -129,18 +130,18 @@ for n=1,math.floor(opt.how_many/opt.batchSize) do
print(output:size())
print(target:size())
for i=1, opt.batchSize do
-image.save(paths.concat(image_dir,'input',filepaths_curr[i]), image.scale(input[i],input[i]:size(2),input[i]:size(3)/opt.aspect_ratio))
-image.save(paths.concat(image_dir,'output',filepaths_curr[i]), image.scale(output[i],output[i]:size(2),output[i]:size(3)/opt.aspect_ratio))
-image.save(paths.concat(image_dir,'target',filepaths_curr[i]), image.scale(target[i],target[i]:size(2),target[i]:size(3)/opt.aspect_ratio))
+image.save(paths.concat(image_dir,'input',filepaths_curr[i]), input[i])--image.scale(input[i],input[i]:size(3),input[i]:size(2))) --/opt.aspect_ratio))
+image.save(paths.concat(image_dir,'output',filepaths_curr[i]), output[i])--image.scale(output[i],output[i]:size(3),output[i]:size(2)))--/opt.aspect_ratio))
+image.save(paths.concat(image_dir,'target',filepaths_curr[i]), target[i])--image.scale(target[i],target[i]:size(3),target[i]:size(2)))--/opt.aspect_ratio))
end
print('Saved images to: ', image_dir)

if opt.display then
if opt.preprocess == 'regular' then
disp = require 'display'
-disp.image(util.scaleBatch(input,100,100),{win=opt.display_id, title='input'})
-disp.image(util.scaleBatch(output,100,100),{win=opt.display_id+1, title='output'})
-disp.image(util.scaleBatch(target,100,100),{win=opt.display_id+2, title='target'})
+disp.image(util.scaleBatch(input,512,256),{win=opt.display_id, title='input'})
+disp.image(util.scaleBatch(output,512,256),{win=opt.display_id+1, title='output'})
+disp.image(util.scaleBatch(target,512,256),{win=opt.display_id+2, title='target'})

print('Displayed images')
end
@@ -164,4 +165,4 @@ for i=1, #filepaths do
io.write('</tr>')
end

-io.write('</table>')
+io.write('</table>')
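With `fineSize` replaced by `imgWidth`/`imgHeight`, a test run on non-square data would be launched along these lines (a sketch; the dataset path and experiment name are hypothetical):

```
DATA_ROOT=./datasets/facades name=facades_experiment imgWidth=256 imgHeight=512 th test.lua
```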
75 changes: 56 additions & 19 deletions train.lua
100755 → 100644
@@ -15,7 +15,8 @@ opt = {
DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
batchSize = 1, -- # images in batch
loadSize = 286, -- scale images to this size
-fineSize = 256, -- then crop to this size
+imgWidth = 256, -- then crop to this size. Both should be multiples of 32...
+imgHeight = 512,
ngf = 64, -- # of gen filters in first conv layer
ndf = 64, -- # of discrim filters in first conv layer
input_nc = 3, -- # of input image channels
@@ -27,6 +28,7 @@ opt = {
flip = 1, -- whether to flip the images for data augmentation
display = 1, -- display samples while training. 0 = false
display_id = 10, -- display window id.
+display_plot = 'errL1', -- which loss values to plot over time. Accepted values include a comma-separated list of: errL1, errG, and errD
gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
name = '', -- name of the experiment, should generally be passed on the command line
which_direction = 'AtoB', -- AtoB or BtoA
@@ -36,7 +38,7 @@ opt = {
save_epoch_freq = 50, -- save a model every save_epoch_freq epochs (does not overwrite previously saved models)
save_latest_freq = 5000, -- save the latest model every latest_freq sgd iterations (overwrites the previous latest model)
print_freq = 50, -- print the debug information every print_freq iterations
-display_freq = 100, -- display the current results every display_freq iterations
+display_freq = 20, -- display the current results every display_freq iterations
save_display_freq = 5000, -- save the current display of results every save_display_freq_iterations
continue_train=0, -- if continue training, load the latest model: 1: true, 0: false
serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly
@@ -165,11 +167,11 @@ optimStateD = {
beta1 = opt.beta1,
}
----------------------------------------------------------------------------
-local real_A = torch.Tensor(opt.batchSize, input_nc, opt.fineSize, opt.fineSize)
-local real_B = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize)
-local fake_B = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize)
-local real_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize)
-local fake_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize)
+local real_A = torch.Tensor(opt.batchSize, input_nc, opt.imgWidth, opt.imgHeight)
+local real_B = torch.Tensor(opt.batchSize, output_nc, opt.imgWidth, opt.imgHeight)
+local fake_B = torch.Tensor(opt.batchSize, output_nc, opt.imgWidth, opt.imgHeight)
+local real_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.imgWidth, opt.imgHeight)
+local fake_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.imgWidth, opt.imgHeight)
local errD, errG, errL1 = 0, 0, 0
local epoch_tm = torch.Timer()
local tm = torch.Timer()
@@ -314,6 +316,25 @@ file = torch.DiskFile(paths.concat(opt.checkpoints_dir, opt.name, 'opt.txt'), 'w
file:writeObject(opt)
file:close()

+-- parse display_plot string into table
+opt.display_plot = string.split(string.gsub(opt.display_plot, "%s+", ""), ",")
+for k, v in ipairs(opt.display_plot) do
+    if not util.containsValue({"errG", "errD", "errL1"}, v) then
+        error(string.format('bad display_plot value "%s"', v))
+    end
+end
+
+-- display plot config
+local plot_config = {
+    title = "Loss over time",
+    labels = {"epoch", unpack(opt.display_plot)},
+    ylabel = "loss",
+}
+
+-- display plot vars
+local plot_data = {}
+local plot_win
+
local counter = 0
for epoch = 1, opt.niter do
epoch_tm:reset()
@@ -328,22 +349,22 @@ for epoch = 1, opt.niter do

-- (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x))
optim.adam(fGx, parametersG, optimStateG)

-- display
counter = counter + 1
if counter % opt.display_freq == 0 and opt.display then
createRealFake()
if opt.preprocess == 'colorization' then
-local real_A_s = util.scaleBatch(real_A:float(),100,100)
-local fake_B_s = util.scaleBatch(fake_B:float(),100,100)
-local real_B_s = util.scaleBatch(real_B:float(),100,100)
+local real_A_s = util.scaleBatch(real_A:float(),512,256)
+local fake_B_s = util.scaleBatch(fake_B:float(),512,256)
+local real_B_s = util.scaleBatch(real_B:float(),512,256)
disp.image(util.deprocessL_batch(real_A_s), {win=opt.display_id, title=opt.name .. ' input'})
disp.image(util.deprocessLAB_batch(real_A_s, fake_B_s), {win=opt.display_id+1, title=opt.name .. ' output'})
disp.image(util.deprocessLAB_batch(real_A_s, real_B_s), {win=opt.display_id+2, title=opt.name .. ' target'})
else
-disp.image(util.deprocess_batch(util.scaleBatch(real_A:float(),100,100)), {win=opt.display_id, title=opt.name .. ' input'})
-disp.image(util.deprocess_batch(util.scaleBatch(fake_B:float(),100,100)), {win=opt.display_id+1, title=opt.name .. ' output'})
-disp.image(util.deprocess_batch(util.scaleBatch(real_B:float(),100,100)), {win=opt.display_id+2, title=opt.name .. ' target'})
+disp.image(util.deprocess_batch(util.scaleBatch(real_A:float(),512,256)), {win=opt.display_id, title=opt.name .. ' input'})
+disp.image(util.deprocess_batch(util.scaleBatch(fake_B:float(),512,256)), {win=opt.display_id+1, title=opt.name .. ' output'})
+disp.image(util.deprocess_batch(util.scaleBatch(real_B:float(),512,256)), {win=opt.display_id+2, title=opt.name .. ' target'})
end
end

@@ -377,14 +398,30 @@ for epoch = 1, opt.niter do
opt.serial_batches=serial_batches
end

--- logging
+-- logging and display plot
if counter % opt.print_freq == 0 then
+local loss = {errG=errG and errG or -1, errD=errD and errD or -1, errL1=errL1 and errL1 or -1}
+local curItInBatch = ((i-1) / opt.batchSize)
+local totalItInBatch = math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize)
print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
.. ' Err_G: %.4f Err_D: %.4f ErrL1: %.4f'):format(
-epoch, ((i-1) / opt.batchSize),
-math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
+epoch, curItInBatch, totalItInBatch,
tm:time().real / opt.batchSize, data_tm:time().real / opt.batchSize,
-errG and errG or -1, errD and errD or -1, errL1 and errL1 or -1))
+errG, errD, errL1))

+local plot_vals = { epoch + curItInBatch / totalItInBatch }
+for k, v in ipairs(opt.display_plot) do
+    if loss[v] ~= nil then
+        plot_vals[#plot_vals + 1] = loss[v]
+    end
+end
+
+-- update display plot
+if opt.display then
+    table.insert(plot_data, plot_vals)
+    plot_config.win = plot_win
+    plot_win = disp.plot(plot_data, plot_config)
+end
end

-- save latest model
Expand All @@ -409,4 +446,4 @@ for epoch = 1, opt.niter do
epoch, opt.niter, epoch_tm:time().real))
parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them
parametersG, gradParametersG = netG:getParameters()
-end
+end
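The new plotting path accumulates one row of `{x, loss...}` values per print interval and redraws the full series into a single reused window. A standalone sketch of that pattern, assuming the same `display` package API already used above:

```
local disp = require 'display'

local config = { title = "Loss over time", labels = {"epoch", "errG", "errD"}, ylabel = "loss" }
local data, win = {}, nil
for t = 1, 10 do
    table.insert(data, { t, 1 / t, 1 / (2 * t) })  -- one row: {x, series1, series2}
    config.win = win                               -- reuse the same plot window
    win = disp.plot(data, config)
end
```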
7 changes: 7 additions & 0 deletions util/util.lua
@@ -219,4 +219,11 @@ function util.cudnn(net)
return cudnn_convert_custom(net, cudnn)
end

+function util.containsValue(table, value)
+    for k, v in pairs(table) do
+        if v == value then return true end
+    end
+    return false
+end
+
return util
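Usage mirrors the validation loop added in train.lua (assuming `util` is the table returned by util/util.lua). Note that the parameter name `table` shadows Lua's standard `table` library inside the function body; that is harmless here since the body never uses the library:

```
print(util.containsValue({"errG", "errD", "errL1"}, "errG"))  -- true
print(util.containsValue({"errG", "errD", "errL1"}, "errX"))  -- false
```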