autobatch for conv2d (#1216)
* autobatch for conv2d

* Fixed random normal and made mean/stddev an option (#1220)

* Argmax and straight-through estimator (#1208)

* Initial argmax

* Add missing file

* Python implementation of argmax

* straight-through

* Use as_vector

* Fixed on GPU

* Change API

* add __gitversion__ info to python build (#1229)

* Added unit test
draplater authored and neubig committed Feb 9, 2018
1 parent e7ad9b8 commit 8462813
Showing 5 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -91,6 +91,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}
function(find_cudnn)
  if (DEFINED ENV{CUDNN_ROOT} AND NOT DEFINED CUDNN_ROOT) # use the environment variable if CUDNN_ROOT is not defined
    set(CUDNN_ROOT $ENV{CUDNN_ROOT})
  elseif (DEFINED CUDA_TOOLKIT_ROOT_DIR AND NOT DEFINED CUDNN_ROOT) # otherwise fall back to the CUDA toolkit root
    set(CUDNN_ROOT ${CUDA_TOOLKIT_ROOT_DIR})
  endif()
  # set(CUDNN_ROOT /usr/local/cuda CACHE PATH "CUDNN root path")
  find_path(CUDNN_INCLUDE_DIRS cudnn.h
22 changes: 22 additions & 0 deletions dynet/nodes-conv2d.cc
@@ -70,6 +70,28 @@ Dim Conv2D::dim_forward(const vector<Dim>& xs) const {
  return Dim(output_shape, bs);
}

int Conv2D::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const {
  Sig s(nt::conv2d);
  // Note that autobatching will only occur when inputs are of batch size one
  // TODO: remove this restriction, allowing for combining batched inputs
  if(dim.bd == 1) {
    s.add_dim(cg.nodes[args[0]]->dim); // the input
    s.add_node(args[1]);               // the filter
    s.add_int(static_cast<int>(is_valid));
    s.add_int(stride[0]);
    s.add_int(stride[1]);
    return sm.get_idx(s);
  } else {
    return 0;
  }
}

std::vector<int> Conv2D::autobatch_concat(const ComputationGraph & cg) const {
  vector<int> ret(args.size(), 0);
  // Concatenate only the input (arg 0) along the batch dimension; the filter is shared.
  if (dim.bd == 1) { ret[0] = 1; }
  return ret;
}
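For intuition, these two hooks are all the autobatcher consults: autobatch_sig hashes everything that must agree for two Conv2D nodes to share one kernel launch (input shape, filter node, padding mode, strides), and autobatch_concat marks only the input for batch-wise concatenation. A minimal usage sketch of the pattern this enables — batched_conv_loss and image_data are hypothetical names, and DyNet is assumed to be initialized with autobatching enabled (e.g. --dynet-autobatch 1):

#include <vector>
#include "dynet/dynet.h"
#include "dynet/expr.h"

using namespace dynet;

// Each conv2d call below has batch size one and shares the same filter,
// stride, and padding mode, so every node gets the same signature and the
// autobatcher can execute them as a single batched convolution.
Expression batched_conv_loss(ComputationGraph& cg, Parameter& param_kernel,
                             const std::vector<std::vector<float>>& image_data) {
  Expression kernel = parameter(cg, param_kernel);      // shared filter
  std::vector<Expression> zs;
  for (const auto& img : image_data) {
    Expression x = input(cg, Dim({50, 50, 2}), img);    // one 50x50x2 instance
    Expression y = conv2d(x, kernel, {3, 3}, true);     // VALID padding, stride 3
    zs.push_back(sum_elems(y));                         // scalar per instance
  }
  return sum(zs);  // evaluating this triggers the batched execution
}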

// size_t Conv2D::aux_storage_size() const {
// vector<unsigned> input_size(arity());
// for (unsigned i = 0; i < arity(); ++i) {
9 changes: 9 additions & 0 deletions dynet/nodes-conv2d.h
@@ -22,6 +22,15 @@ struct Conv2D: public Node {
         const bool padding_type = true)
      : Node(a), stride(s), is_valid(padding_type) {}
  virtual bool supports_multibatch() const override { return true; }
  virtual int autobatch_sig(const ComputationGraph &cg, SigMap &sm) const override;
  virtual std::vector<int> autobatch_concat(const ComputationGraph & cg) const override;
  virtual void autobatch_reshape(const ComputationGraph & cg,
                                 const std::vector<VariableIndex> & batch_ids,
                                 const std::vector<int> & concat,
                                 std::vector<const Tensor*>& xs,
                                 Tensor& fx) const override {
    // Only concatenated inputs and the output need their batch dimensions
    // adjusted, so the generic concat-only reshape suffices.
    autobatch_reshape_concatonly(cg, batch_ids, concat, xs, fx);
  }
  DYNET_NODE_DEFINE_DEV_IMPL()
  const std::vector<unsigned> stride;
  const bool is_valid;
1 change: 1 addition & 0 deletions dynet/sig.h
@@ -19,6 +19,7 @@ namespace dynet {
  COMPLEX,
  affine, matmul,
  vanilla_lstm_gates, vanilla_lstm_h, vanilla_lstm_c,
  conv2d
};
}

34 changes: 34 additions & 0 deletions tests/test-nodes.cc
@@ -1721,6 +1721,40 @@ BOOST_AUTO_TEST_CASE( conv2d_valid_gradient ) {
  BOOST_CHECK(check_grad(mod, z, 0));
}

// Expression conv2d(const Expression& x, const Expression& f, const std::vector<unsigned>& stride, bool is_valid);
BOOST_AUTO_TEST_CASE( conv2d_autobatch_gradient ) {
  // Save the current autobatching setting and turn batching on for this test.
  auto autobatch_cache = dynet::autobatch_flag;
  dynet::autobatch_flag = 1;
  dynet::ComputationGraph cg;
  Parameter param_kernel = mod.add_parameters({2, 2, 2, 3});
  std::vector<float> param_kernel_vals = {.011f, .022f, .033f, .012f, .022f, .032f, .013f, .023f, .033f,
                                          .111f, -.122f, -.033f, -.112f, -.022f, -.132f, -.113f, -.123f, -.133f,
                                          .211f, .222f, .233f, .212f, .222f, .232f};
  TensorTools::set_elements(param_kernel.get_storage().values, param_kernel_vals);
  Expression kernel = parameter(cg, param_kernel);
  vector<unsigned> stride = {3, 3}; bool is_valid = true;
  std::vector<float> conv2d_vals1(50 * 50 * 2), conv2d_vals2(50 * 50 * 2);
  for (unsigned i = 0; i < conv2d_vals1.size(); ++i) {
    conv2d_vals1[i] = i * 0.011f + (i + 1) * 0.001f;
    conv2d_vals2[i] = i * 0.015f + (i + 1) * -0.001f;
  }
  // Two single-batch convolutions with a shared filter: the autobatchable pattern.
  vector<Expression> zs;
  {
    Expression x = input(cg, Dim({50, 50, 2}), conv2d_vals1);
    Expression y = conv2d(x, kernel, stride, is_valid);
    zs.push_back(to_scalar(y));
  }
  {
    Expression x = input(cg, Dim({50, 50, 2}), conv2d_vals2);
    Expression y = conv2d(x, kernel, stride, is_valid);
    zs.push_back(to_scalar(y));
  }
  Expression z = sum(zs);
  BOOST_CHECK(check_grad(mod, z, 0));
  // Restore the previous autobatching setting.
  dynet::autobatch_flag = autobatch_cache;
}
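Outside the test harness, autobatching is normally switched on for the whole program at initialization rather than by touching dynet::autobatch_flag directly; a short sketch, assuming a standard DyNet build (--dynet-autobatch is DyNet's documented command-line switch):

#include "dynet/dynet.h"

int main(int argc, char** argv) {
  // e.g. run as: ./my_program --dynet-autobatch 1
  // initialize() consumes DyNet's own flags from argv, including the
  // autobatching switch that the test above toggles via autobatch_flag.
  dynet::initialize(argc, argv);
  // ... build computation graphs; conv2d nodes with matching signatures
  // are now merged into batched kernel calls ...
  return 0;
}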

// Expression conv2d(const Expression& x, const Expression& f, const std::vector<unsigned>& stride, bool is_valid);
BOOST_AUTO_TEST_CASE( conv2d_valid_singlefilter_gradient ) {
  dynet::ComputationGraph cg;
