Skip to content

Commit

Permalink
Removing 'float' from the template to make it C++17 compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
Cydral authored Sep 23, 2024
1 parent 078c9cb commit cf53441
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 23 deletions.
17 changes: 11 additions & 6 deletions dlib/dnn/layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4637,11 +4637,16 @@ namespace dlib

// ----------------------------------------------------------------------------------------

template <long diag_, float diag_value_>
// Carries a compile-time constant as a *type*, so a value such as 0.0f or
// -infinity can be handed to a class template (e.g. the diag_value_
// parameter of tril_) that only accepts type template parameters.
// NOTE(review): instantiating this with a floating-point T (as in
// float_constant<float, 0.0f>) still relies on a floating-point non-type
// template argument, which is standard only from C++20 — confirm the
// supported compilers accept such instantiations in C++17 mode.
template <typename T, T val>
struct float_constant {
    // The wrapped constant, retrievable at compile time.
    static constexpr T value{val};
};

template <long diag_, typename diag_value_>
class tril_
{
public:
tril_() : diag(diag_), diag_value(diag_value_) {}
tril_() : diag(diag_), diag_value(diag_value_::value) {}

template <typename SUBNET>
void setup(const SUBNET& sub) {
Expand Down Expand Up @@ -4728,13 +4733,13 @@ namespace dlib
};

template <typename SUBNET>
using tril = add_layer<tril_<0, 0.0f>, SUBNET>;
using tril = add_layer<tril_<0, float_constant<float, 0.0f>>, SUBNET>;

template <typename SUBNET>
using tril_mask = add_layer<tril_<0, -std::numeric_limits<float>::infinity()>, SUBNET>;
using tril_mask = add_layer<tril_<0, float_constant<float, -std::numeric_limits<float>::infinity()>>, SUBNET>;

template <long diag, float diag_value, typename SUBNET>
using tril_diag = add_layer<tril_<diag, diag_value>, SUBNET>;
template <long diag, typename diag_value_type, typename SUBNET>
using tril_diag = add_layer<tril_<diag, diag_value_type>, SUBNET>;

// ----------------------------------------------------------------------------------------

Expand Down
22 changes: 14 additions & 8 deletions dlib/dnn/layers_abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -3651,21 +3651,26 @@ namespace dlib

// ----------------------------------------------------------------------------------------

template <long diag_, float diag_value_>
// Carries a compile-time constant as a type, so a value such as 0.0f or
// -infinity can be supplied where a type template parameter is expected
// (e.g. the diag_value_ parameter of tril_ below).
// NOTE(review): an instantiation with floating-point T (e.g.
// float_constant<float, 0.0f>) uses a floating-point non-type template
// argument, which is standard only from C++20 — verify this compiles with
// the project's C++17 toolchains, since C++17 compatibility is the goal.
template <typename T, T val>
struct float_constant {
// The wrapped constant, retrievable at compile time as ::value.
static constexpr T value = val;
};

template <long diag_, typename diag_value_>
class tril_
{
/*!
REQUIREMENTS ON diag_ and diag_value_
- diag_ must be an integer; negative values select diagonals below the main diagonal (e.g. tril_<-5, ...>).
- diag_value_ must be a floating point number.
- diag_value_::value must be a floating point number.
WHAT THIS OBJECT REPRESENTS
This object implements a layer in a deep neural network that applies a lower triangular mask to
its input tensor. The mask is defined such that all elements above the specified diagonal are set
to a given value (diag_value_). The diagonal is specified by the diag_ parameter.
to a given value (diag_value_::value). The diagonal is specified by the diag_ parameter.
EXAMPLE USAGE
tril_<0, -std::numeric_limits<float>::infinity()> layer;
tril_<0, float_constant<float, -std::numeric_limits<float>::infinity()>> layer;
// This creates a layer that masks all elements above the main diagonal with -inf.
SERIALIZATION SUPPORT
Expand Down Expand Up @@ -3756,13 +3761,14 @@ namespace dlib
};

template <typename SUBNET>
using tril = add_layer<tril_<0, 0.0f>, SUBNET>;
using tril = add_layer<tril_<0, float_constant<float, 0.0f>>, SUBNET>;

template <typename SUBNET>
using tril_mask = add_layer<tril_<0, -std::numeric_limits<float>::infinity()>, SUBNET>;
using tril_mask = add_layer<tril_<0, float_constant<float, -std::numeric_limits<float>::infinity()>>, SUBNET>;

template <long diag, typename diag_value_type, typename SUBNET>
using tril_diag = add_layer<tril_<diag, diag_value_type>, SUBNET>;

template <long diag, float diag_value, typename SUBNET>
using tril_diag = add_layer<tril_<diag, diag_value>, SUBNET>;

// ----------------------------------------------------------------------------------------

Expand Down
8 changes: 4 additions & 4 deletions dlib/dnn/visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -1021,15 +1021,15 @@ namespace dlib
update(i);
}

template <long diag, float diag_value, typename U, typename E>
void operator()(size_t i, const add_layer<tril_<diag, diag_value>, U, E>&)
template <long diag, typename diag_value_type, typename U, typename E>
void operator()(size_t i, const add_layer<tril_<diag, diag_value_type>, U, E>&)
{
start_node(i, "tril");
out << " | {diag|{" << diag << "}}";
out << " | {diag_value|{" << diag_value << "}}";
out << " | {diag_value|{" << diag_value_type::value << "}}";
end_node();
update(i);
}
}

template <typename T, typename U, typename E>
void operator()(size_t i, const add_layer<T, U, E>&)
Expand Down
11 changes: 6 additions & 5 deletions dlib/test/dnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1994,7 +1994,8 @@ namespace
}
{
print_spinner();
tril_<-5, -std::numeric_limits<float>::infinity()> l;
using specific_float = float_constant<float, -std::numeric_limits<float>::infinity()>;
tril_<-5, specific_float> l;
auto res = test_layer(l);
DLIB_TEST_MSG(res, res);
}
Expand Down Expand Up @@ -4370,7 +4371,7 @@ void test_tril()
{
print_spinner();

constexpr float NEG_INF = -std::numeric_limits<float>::infinity();
using NEG_INF = float_constant<float, -std::numeric_limits<float>::infinity()>;
using net_type = tag1<tril_diag<0, NEG_INF, tag2<input<matrix<float>>>>>;
net_type net;

Expand All @@ -4397,9 +4398,9 @@ void test_tril()
expected_output.copy_size(input_tensor);
tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
for (int ii = 0; ii < n_samples; ++ii) {
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = NEG_INF;
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = NEG_INF;
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = NEG_INF;
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
}

// Compare output tensor with expected output
Expand Down

0 comments on commit cf53441

Please sign in to comment.