diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h index 9eafe941c1..8f6bdd240e 100644 --- a/dlib/dnn/layers.h +++ b/dlib/dnn/layers.h @@ -4637,11 +4637,16 @@ namespace dlib // ---------------------------------------------------------------------------------------- - template + template + struct float_constant { + static constexpr T value = val; + }; + + template class tril_ { public: - tril_() : diag(diag_), diag_value(diag_value_) {} + tril_() : diag(diag_), diag_value(diag_value_::value) {} template void setup(const SUBNET& sub) { @@ -4728,13 +4733,13 @@ namespace dlib }; template - using tril = add_layer, SUBNET>; + using tril = add_layer>, SUBNET>; template - using tril_mask = add_layer::infinity()>, SUBNET>; + using tril_mask = add_layer::infinity()>>, SUBNET>; - template - using tril_diag = add_layer, SUBNET>; + template + using tril_diag = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h index 463dcc5683..0c578f432b 100644 --- a/dlib/dnn/layers_abstract.h +++ b/dlib/dnn/layers_abstract.h @@ -3651,21 +3651,26 @@ namespace dlib // ---------------------------------------------------------------------------------------- - template + template + struct float_constant { + static constexpr T value = val; + }; + + template class tril_ { /*! REQUIREMENTS ON diag_ and diag_value_ - diag_ must be a non-negative integer. - - diag_value_ must be a floating point number. + - diag_value_::value must be a floating point number. WHAT THIS OBJECT REPRESENTS This object implements a layer in a deep neural network that applies a lower triangular mask to its input tensor. The mask is defined such that all elements above the specified diagonal are set - to a given value (diag_value_). The diagonal is specified by the diag_ parameter. + to a given value (diag_value_::value). The diagonal is specified by the diag_ parameter. EXAMPLE USAGE - tril_<0, -std::numeric_limits::infinity()> layer; + tril_<0, float_constant::infinity()>> layer; // This creates a layer that masks all elements above the main diagonal with -inf. SERIALIZATION SUPPORT @@ -3756,13 +3761,14 @@ namespace dlib }; template - using tril = add_layer, SUBNET>; + using tril = add_layer>, SUBNET>; template - using tril_mask = add_layer::infinity()>, SUBNET>; + using tril_mask = add_layer::infinity()>>, SUBNET>; + + template + using tril_diag = add_layer, SUBNET>; - template - using tril_diag = add_layer, SUBNET>; // ---------------------------------------------------------------------------------------- diff --git a/dlib/dnn/visitors.h b/dlib/dnn/visitors.h index 9f2bbd463b..e6c868233f 100644 --- a/dlib/dnn/visitors.h +++ b/dlib/dnn/visitors.h @@ -1021,15 +1021,15 @@ namespace dlib update(i); } - template - void operator()(size_t i, const add_layer, U, E>&) + template + void operator()(size_t i, const add_layer, U, E>&) { start_node(i, "tril"); out << " | {diag|{" << diag << "}}"; - out << " | {diag_value|{" << diag_value << "}}"; + out << " | {diag_value|{" << diag_value_type::value << "}}"; end_node(); update(i); - } + } template void operator()(size_t i, const add_layer&) diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp index 03fc375052..1c56d81534 100644 --- a/dlib/test/dnn.cpp +++ b/dlib/test/dnn.cpp @@ -1994,7 +1994,8 @@ namespace } { print_spinner(); - tril_<-5, -std::numeric_limits::infinity()> l; + using specific_float = float_constant::infinity()>; + tril_<-5, specific_float> l; auto res = test_layer(l); DLIB_TEST_MSG(res, res); } @@ -4370,7 +4371,7 @@ void test_tril() { print_spinner(); - constexpr float NEG_INF = -std::numeric_limits::infinity(); + using NEG_INF = float_constant::infinity()>; using net_type = tag1>>>>; net_type net; @@ -4397,9 +4398,9 @@ void test_tril() expected_output.copy_size(input_tensor); tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k()); for (int ii = 0; ii < n_samples; ++ii) { - expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = NEG_INF; - expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = NEG_INF; - expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = NEG_INF; + expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits::infinity(); + expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits::infinity(); + expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits::infinity(); } // Compare output tensor with expected output