From 8462813343518aaa5a333c12113e1619ddcec4ee Mon Sep 17 00:00:00 2001
From: draplater
Date: Sat, 10 Feb 2018 01:20:15 +0800
Subject: [PATCH] autobatch for conv2d (#1216)

* autobatch for conv2d

* Fixed random normal and made mean/stddev an option (#1220)

* Argmax and straight-through estimator (#1208)

* Initial argmax

* Add missing file

* Python implementation of argmax

* straight-through

* Use as_vector

* Fixed on GPU

* Change API

* add __gitversion__ info to python build (#1229)

* Added unit test
---
 CMakeLists.txt        |  2 ++
 dynet/nodes-conv2d.cc | 22 ++++++++++++++++++++++
 dynet/nodes-conv2d.h  |  9 +++++++++
 dynet/sig.h           |  1 +
 tests/test-nodes.cc   | 34 ++++++++++++++++++++++++++++++++++
 5 files changed, 68 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7bfc5fb49..f592036ba 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,6 +91,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}
 function(find_cudnn)
   if (DEFINED ENV{CUDNN_ROOT} AND NOT DEFINED CUDNN_ROOT) # use env variable if not defined
     set(CUDNN_ROOT $ENV{CUDNN_ROOT})
+  elseif (DEFINED CUDA_TOOLKIT_ROOT_DIR AND NOT DEFINED CUDNN_ROOT) # fall back to the CUDA toolkit root
+    set(CUDNN_ROOT ${CUDA_TOOLKIT_ROOT_DIR})
   endif()
   # set(CUDNN_ROOT /usr/local/cuda CACHE PATH "CUDNN root path")
   find_path(CUDNN_INCLUDE_DIRS cudnn.h
diff --git a/dynet/nodes-conv2d.cc b/dynet/nodes-conv2d.cc
index 024114fa2..4ce7b6002 100644
--- a/dynet/nodes-conv2d.cc
+++ b/dynet/nodes-conv2d.cc
@@ -70,6 +70,28 @@ Dim Conv2D::dim_forward(const vector<Dim>& xs) const {
   return Dim(output_shape, bs);
 }
 
+int Conv2D::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const {
+  Sig s(nt::conv2d);
+  // Note that autobatching will only occur when inputs are of batch size one
+  // TODO: remove this restriction, allowing for combining batched inputs
+  if (dim.bd == 1) {
+    s.add_dim(cg.nodes[args[0]]->dim); // the input
+    s.add_node(args[1]);               // the filter
+    s.add_int(static_cast<int>(is_valid));
+    s.add_int(stride[0]);
+    s.add_int(stride[1]);
+    return sm.get_idx(s);
+  } else {
+    return 0;
+  }
+}
+
+std::vector<int> Conv2D::autobatch_concat(const ComputationGraph & cg) const {
+  vector<int> ret(args.size(), 0);
+  if (dim.bd == 1) { ret[0] = 1; }
+  return ret;
+}
+
 // size_t Conv2D::aux_storage_size() const {
 //   vector<unsigned> input_size(arity());
 //   for (unsigned i = 0; i < arity(); ++i) {
diff --git a/dynet/nodes-conv2d.h b/dynet/nodes-conv2d.h
index 891d2b4d3..9028f1cd4 100644
--- a/dynet/nodes-conv2d.h
+++ b/dynet/nodes-conv2d.h
@@ -22,6 +22,15 @@ struct Conv2D: public Node {
                  const bool padding_type = true)
       : Node(a), stride(s), is_valid(padding_type) {}
   virtual bool supports_multibatch() const override { return true; }
+  virtual int autobatch_sig(const ComputationGraph &cg, SigMap &sm) const override;
+  virtual std::vector<int> autobatch_concat(const ComputationGraph & cg) const override;
+  virtual void autobatch_reshape(const ComputationGraph & cg,
+                                 const std::vector<VariableIndex> & batch_ids,
+                                 const std::vector<int> & concat,
+                                 std::vector<const Tensor*>& xs,
+                                 Tensor& fx) const override {
+    autobatch_reshape_concatonly(cg, batch_ids, concat, xs, fx);
+  }
   DYNET_NODE_DEFINE_DEV_IMPL()
   const std::vector<unsigned> stride;
   const bool is_valid;
diff --git a/dynet/sig.h b/dynet/sig.h
index acdc33d1f..160734ff4 100644
--- a/dynet/sig.h
+++ b/dynet/sig.h
@@ -19,6 +19,7 @@ namespace dynet {
     COMPLEX,
     affine, matmul,
     vanilla_lstm_gates, vanilla_lstm_h, vanilla_lstm_c,
+    conv2d
   };
 }
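Reviewer note: the signature built in autobatch_sig above means two conv2d nodes are folded together only when each is unbatched (dim.bd == 1), their inputs have identical dimensions, they reference the same filter node, and their padding mode and strides agree. A minimal fragment illustrating what does and does not share a signature; it assumes an enclosing ComputationGraph with expressions x1, x2 (each Dim({50, 50, 2}), batch size 1) and a shared filter expression, none of which are part of this patch:

  // Same filter node, same input dims, same stride and padding mode:
  // identical signatures, so the autobatcher may execute these two
  // nodes as a single batched conv2d.
  std::vector<unsigned> stride = {3, 3};
  Expression y1 = conv2d(x1, filter, stride, /*is_valid=*/true);
  Expression y2 = conv2d(x2, filter, stride, /*is_valid=*/true);
  // A different stride (likewise a different filter node or a batched
  // input) yields a different signature, so this node runs on its own.
  Expression y3 = conv2d(x1, filter, {1, 1}, /*is_valid=*/true);

The unit test below exercises exactly the first pattern: two unbatched inputs convolved with one shared kernel.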
diff --git a/tests/test-nodes.cc b/tests/test-nodes.cc
index 35cf43a6d..e6b0b563f 100644
--- a/tests/test-nodes.cc
+++ b/tests/test-nodes.cc
@@ -1721,6 +1721,40 @@ BOOST_AUTO_TEST_CASE( conv2d_valid_gradient ) {
   BOOST_CHECK(check_grad(mod, z, 0));
 }
 
+// Expression conv2d(const Expression& x, const Expression& f, const std::vector<unsigned>& stride, bool is_valid);
+BOOST_AUTO_TEST_CASE( conv2d_autobatch_gradient ) {
+  auto autobatch_cache = dynet::autobatch_flag;
+  dynet::autobatch_flag = 1;
+  dynet::ComputationGraph cg;
+  Parameter param_kernel = mod.add_parameters({2, 2, 2, 3});
+  std::vector<float> param_kernel_vals = {.011f, .022f, .033f, .012f, .022f, .032f, .013f, .023f, .033f,
+                                          .111f, -.122f, -.033f, -.112f, -.022f, -.132f, -.113f, -.123f, -.133f,
+                                          .211f, .222f, .233f, .212f, .222f, .232f
+  };
+  TensorTools::set_elements(param_kernel.get_storage().values, param_kernel_vals);
+  Expression kernel = parameter(cg, param_kernel);
+  vector<unsigned> stride = {3, 3}; bool is_valid = true;
+  std::vector<float> conv2d_vals1(50 * 50 * 2), conv2d_vals2(50 * 50 * 2);
+  for (unsigned i = 0; i < conv2d_vals1.size(); ++i) {
+    conv2d_vals1[i] = i * 0.011f + (i + 1) * 0.001f;
+    conv2d_vals2[i] = i * 0.015f + (i + 1) * -0.001f;
+  }
+  vector<Expression> zs;
+  {
+    Expression x = input(cg, Dim({50, 50, 2}), conv2d_vals1);
+    Expression y = conv2d(x, kernel, stride, is_valid);
+    zs.push_back(to_scalar(y));
+  }
+  {
+    Expression x = input(cg, Dim({50, 50, 2}), conv2d_vals2);
+    Expression y = conv2d(x, kernel, stride, is_valid);
+    zs.push_back(to_scalar(y));
+  }
+  Expression z = sum(zs);
+  BOOST_CHECK(check_grad(mod, z, 0));
+  dynet::autobatch_flag = autobatch_cache;
+}
+
 // Expression conv2d(const Expression& x, const Expression& f, const std::vector<unsigned>& stride, bool is_valid);
 BOOST_AUTO_TEST_CASE( conv2d_valid_singlefilter_gradient ) {
   dynet::ComputationGraph cg;
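For completeness, a hedged end-to-end sketch of exercising the new batching from user code. It is illustrative only, assembled from the same public API the test uses; the file name, image count, sizes, and the sum_elems reduction are assumptions, not part of the patch. Autobatching is enabled here via the standard --dynet-autobatch command-line flag rather than by setting dynet::autobatch_flag directly as the test does:

// conv2d_autobatch_demo.cc -- illustrative sketch, not part of the patch.
// Build against DyNet and run as: ./conv2d_autobatch_demo --dynet-autobatch 1
#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <vector>
using namespace dynet;

int main(int argc, char** argv) {
  initialize(argc, argv);  // picks up --dynet-autobatch from argv
  ParameterCollection model;
  // Filter dims {XH, XW, Ci, Co}: a 2x2 window, 2 input channels, 3 output channels.
  Parameter p_filter = model.add_parameters({2, 2, 2, 3});

  // Keep the input buffers alive for the lifetime of the graph.
  std::vector<std::vector<float>> images(8, std::vector<float>(50 * 50 * 2));
  for (unsigned n = 0; n < images.size(); ++n)
    for (unsigned i = 0; i < images[n].size(); ++i)
      images[n][i] = 0.001f * static_cast<float>(i % 97) * (n + 1);

  ComputationGraph cg;
  Expression filter = parameter(cg, p_filter);
  std::vector<unsigned> stride = {3, 3};
  std::vector<Expression> losses;
  for (unsigned n = 0; n < images.size(); ++n) {
    // Each input is a single 50x50 2-channel image with batch size 1,
    // satisfying the dim.bd == 1 restriction noted in autobatch_sig.
    Expression x = input(cg, Dim({50, 50, 2}), images[n]);
    Expression y = conv2d(x, filter, stride, /*is_valid=*/true);
    losses.push_back(sum_elems(y));
  }
  // All eight conv2d nodes share one signature (same input dims, same
  // filter node, same stride and padding), so with autobatching on the
  // forward pass can run them as a single batched convolution.
  Expression total = sum(losses);
  cg.forward(total);
  cg.backward(total);
  return 0;
}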