Learning Goal-Oriented Visual Dialog via Tempered Policy Gradient (SLT 2018) (IJCAIw 2018)

ruizhaogit/GuessWhat-TemperedPolicyGradient

GuessWhat-TemperedPolicyGradient

The original GuessWhat?! Game repo is based on TensorFlow and available at:
https://github.com/GuessWhatGame/guesswhat.

This repo is based on Torch7 (Lua) and improves the performance by 14% using advanced RNN structures and Tempered Policy Gradient.
The paper is available at https://arxiv.org/abs/1807.00737.

The code was developed by Rui Zhao (Siemens AG & Ludwig Maximilian University of Munich).
The implementation is tested on Ubuntu 14.04 using a single GPU with 12GB memory.

  1. Installation:
  • Our code is implemented in Torch (Lua). Installation instructions are as follows:
git clone https://github.com/ruizhaogit/GuessWhat-TemperedPolicyGradient.git
  • The data preprocessing Python script requires the following packages: numpy, h5py, nltk, json_lines, and tqdm. You can install them with pip:
pip install numpy
pip install h5py
pip install nltk
pip install json_lines
pip install tqdm

To use the NLTK tokenizer, you need to download the required data from within Python:

import nltk
nltk.download('punkt')

  • To install Torch and the required Lua packages:
git clone https://github.com/torch/distro.git ~/torch --recursive
cd ~/torch; bash install-deps;
./install.sh
source ~/.bashrc

sudo apt-get install libprotobuf-dev protobuf-compiler
luarocks install loadcaffe
luarocks install dp
luarocks install nngraph
luarocks install lua-cjson

git clone https://github.com/Element-Research/rnn.git
cd rnn
luarocks make rocks/rnn-scm-1.rockspec

luarocks install luabitop
sudo apt-get install libhdf5-serial-dev hdf5-tools
git clone https://github.com/deepmind/torch-hdf5
cd torch-hdf5
luarocks make hdf5-0-0.rockspec LIBHDF5_LIBDIR="/usr/lib/x86_64-linux-gnu/"

Installation instructions for torch-hdf5 are given in the torch-hdf5 repository (https://github.com/deepmind/torch-hdf5).
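
Before running the preprocessing script, it can help to confirm that the Python packages listed above are importable. The following is a minimal sketch and not part of the repository: the helper name `missing_packages` is hypothetical, and the package list is copied from the pip commands above.

```python
import importlib

# Package names taken from the pip commands in this README.
REQUIRED = ["numpy", "h5py", "nltk", "json_lines", "tqdm"]


def missing_packages(names):
    """Return the subset of `names` that cannot be imported (hypothetical helper)."""
    missing = []
    for name in names:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages: " + ", ".join(missing))
    else:
        print("All required packages are importable.")
```

Run it with the same interpreter that will run `data/preprocessQAs.py`, since packages installed for one Python version are not visible to another.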

  2. Download datasets:
wget https://s3-us-west-2.amazonaws.com/guess-what/guesswhat.train.jsonl.gz -P data/ 
wget https://s3-us-west-2.amazonaws.com/guess-what/guesswhat.valid.jsonl.gz -P data/  
wget https://s3-us-west-2.amazonaws.com/guess-what/guesswhat.test.jsonl.gz -P data/  
gunzip data/*.jsonl.gz
wget http://images.cocodataset.org/zips/train2014.zip -P data/  
wget http://images.cocodataset.org/zips/val2014.zip -P data/  
unzip 'data/*.zip' -d data/images/ | awk 'BEGIN {ORS=" "} {if(NR%1000==0)print "."}'  
rm data/*.zip  
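
After the downloads finish, a quick way to check that the dialogue files were unpacked correctly is to count their records. This is a hedged sketch rather than code from the repository: it reads the file with the standard `json` module instead of the `json_lines` package mentioned above, the helper names `iter_jsonl` and `count_records` are hypothetical, and it makes no assumption about the fields inside each record.

```python
import io
import json
import os


def iter_jsonl(path):
    """Yield one parsed JSON object per non-empty line of a .jsonl file."""
    with io.open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


def count_records(path):
    """Count the records in a .jsonl file without loading it all into memory."""
    return sum(1 for _ in iter_jsonl(path))


if __name__ == "__main__":
    # Path produced by the wget and gunzip commands above.
    path = "data/guesswhat.train.jsonl"
    if os.path.exists(path):
        print("%s: %d records" % (path, count_records(path)))
    else:
        print("%s not found; run the download commands first." % path)
```

A truncated or corrupted download would typically surface here as a `ValueError` raised by `json.loads` on the damaged line.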
  3. Preprocess data:
  • Download the VGG-16 model files and preprocess the images:
wget https://gist.githubusercontent.com/ksimonyan/211839e770f7b538e2d8/raw/ded9363bd93ec0c770134f4e387d8aaaaa2407ce/VGG_ILSVRC_16_layers_deploy.prototxt -P model/vgg16/  
wget http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel -P model/vgg16/
th data/preprocessImages.lua
  • Preprocess the GuessWhat?! dataset. This step parses the QAs with NLTK, builds the dictionary, computes the spatial features, and encodes the category information:
python2.7 data/preprocessQAs.py
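
To make the dictionary and spatial-feature steps more concrete, here is a minimal sketch of what such preprocessing can look like. It is not the code in `data/preprocessQAs.py`. The regex tokenizer is a simple stand-in for the NLTK tokenizer, and the function names, the `<pad>` and `<unk>` tokens, the `(x, y, w, h)` pixel box format, and the eight-value feature layout scaled to [-1, 1] are all assumptions made for illustration.

```python
import re
from collections import Counter


def tokenize(text):
    """Lowercase and split into word and punctuation tokens (stand-in for NLTK)."""
    return re.findall(r"[a-z0-9]+|[^\sa-z0-9]", text.lower())


def build_vocab(sentences, min_count=1, specials=("<pad>", "<unk>")):
    """Map each token seen at least `min_count` times to an integer id."""
    counts = Counter(tok for s in sentences for tok in tokenize(s))
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok, count in sorted(counts.items()):
        if count >= min_count and tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def encode(sentence, vocab):
    """Convert a sentence to a list of ids, using <unk> for unknown tokens."""
    unk = vocab["<unk>"]
    return [vocab.get(tok, unk) for tok in tokenize(sentence)]


def spatial_features(bbox, img_w, img_h):
    """Turn a pixel box (x, y, w, h) into 8 values scaled to [-1, 1].

    Assumed layout: x_min, y_min, x_max, y_max, x_center, y_center, width, height.
    """
    x, y, w, h = bbox
    x_min = 2.0 * x / img_w - 1.0
    y_min = 2.0 * y / img_h - 1.0
    x_max = 2.0 * (x + w) / img_w - 1.0
    y_max = 2.0 * (y + h) / img_h - 1.0
    x_center = (x_min + x_max) / 2.0
    y_center = (y_min + y_max) / 2.0
    return [x_min, y_min, x_max, y_max, x_center, y_center,
            x_max - x_min, y_max - y_min]


if __name__ == "__main__":
    vocab = build_vocab(["Is it a person?", "Is it red?"])
    print(encode("Is it blue?", vocab))
    print(spatial_features((160, 120, 320, 240), 640, 480))
```

Sorting the tokens before assigning ids keeps the dictionary deterministic across runs, which matters when the ids are saved and later loaded by a separate training script.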
  4. Train and evaluate the models:
  • To train the Oracle, Guesser, and QGen (Question-Generator) models, run the Lua scripts separately:
th oracle.lua --learningRate 0.0001 --epochs 30 --earlyStopThresh 10 --learningRateDecay 1e-7
th guesser.lua --learningRate 0.0001 --epochs 30 --earlyStopThresh 10 --learningRateDecay 1e-7
th qgen.lua --learningRate 0.001 --epochs 30 --earlyStopThresh 10 --learningRateDecay 0.5
  • For the reinforcement learning part, if you want to use the standard REINFORCE method, please run:
th reinforce.lua --temperatures '{1.0}' --epochs 80 --earlyStopThresh 20 --learningRate 0.001
  • If you want to use Single-TPG, please run:
th reinforce.lua --temperatures '{1.5}' --epochs 80 --earlyStopThresh 20 --learningRate 0.001
  • If you want to use Parallel-TPG, please run:
th reinforce.lua --temperatures '{1.0, 1.5}' --epochs 80 --earlyStopThresh 20 --learningRate 0.001
  • Finally, if you want to use Dynamic-TPG, please run:
th reinforce.lua --DynamicTPG --tempMin 0.5 --tempMax 1.5 --epochs 80 --earlyStopThresh 20 --learningRate 0.001
  • After training, we obtained the following results:
Method         Accuracy
REINFORCE      69.66%
Single-TPG     69.76%
Parallel-TPG   73.86%
Dynamic-TPG    74.31%
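
The `--temperatures` values in the commands above control how peaked the word distribution is when the question generator samples during reinforcement learning. As a rough illustration, the sketch below assumes the temperature is applied by dividing the logits before the softmax. That is an assumption about `reinforce.lua`, not a description of it, and the function names are hypothetical. With a temperature of 1.0 the result is the ordinary softmax, which matches the standard REINFORCE setting above. Dynamic-TPG is not covered here.

```python
import math
import random


def tempered_softmax(logits, temperature=1.0):
    """Softmax over logits / temperature; temperature > 1 flattens the distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [float(l) / temperature for l in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(probs, rng):
    """Draw one index from a categorical distribution given as a list of probabilities."""
    r = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off


def entropy(probs):
    """Shannon entropy in nats; higher means more exploratory sampling."""
    return -sum(p * math.log(p) for p in probs if p > 0)


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate words
    rng = random.Random(0)
    for temperature in (1.0, 1.5):
        probs = tempered_softmax(logits, temperature)
        print(temperature, probs, entropy(probs), sample_index(probs, rng))
```

Raising the temperature above 1.0 increases the entropy of the sampling distribution, so less likely words are tried more often.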
  5. Citation:

Citation of the IJCAI workshop paper:

@article{zhao2018improving,
  title={Improving Goal-Oriented Visual Dialog Agents via Advanced Recurrent Nets with Tempered Policy Gradient},
  author={Zhao, Rui and Tresp, Volker},
  journal={arXiv preprint arXiv:1807.00737},
  year={2018}
}

Citation of the published IEEE SLT paper:

@inproceedings{zhao2018learning,
  title={Learning goal-oriented visual dialog via tempered policy gradient},
  author={Zhao, Rui and Tresp, Volker},
  booktitle={2018 IEEE Spoken Language Technology Workshop (SLT)},
  pages={868--875},
  year={2018},
  organization={IEEE}
}
