diff --git a/dlib/dnn/layers.h b/dlib/dnn/layers.h
index 74de214ee0..47909ae4ca 100644
--- a/dlib/dnn/layers.h
+++ b/dlib/dnn/layers.h
@@ -4790,6 +4790,132 @@ namespace dlib
     template <typename SUBNET>
     using transpose = add_layer<transpose_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    struct neg_infinity_tag {};
+    struct zero_tag {};
+
+    template<typename T>
+    struct is_special_value : std::false_type {};
+    template<>
+    struct is_special_value<neg_infinity_tag> : std::true_type {};
+    template<>
+    struct is_special_value<zero_tag> : std::true_type {};
+
+    template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
+    class tril_
+    {
+    public:
+        tril_() : diag(diag_), diag_value(compute_diag_value()) {}
+
+        template <typename SUBNET>
+        void setup(const SUBNET& /*sub*/)
+        {
+        }
+
+        template <typename SUBNET>
+        void forward(const SUBNET& sub, resizable_tensor& output)
+        {
+            auto& prev = sub.get_output();
+            output.set_size(prev.num_samples(), prev.k(), prev.nr(), prev.nc());
+
+            check_mask(prev);
+            tt::multiply(false, output, prev, binary_mask);
+            if (diag_value != 0.0f) tt::add(1, output, 1, output_mask);
+        }
+        template <typename SUBNET>
+        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
+        {
+            auto& prev_grad = sub.get_gradient_input();
+            tt::multiply(true, prev_grad, gradient_input, binary_mask);
+        }
+
+        inline dpoint map_input_to_output(const dpoint& p) const { return p; }
+        inline dpoint map_output_to_input(const dpoint& p) const { return p; }
+
+        const tensor& get_layer_params() const { return params; }
+        tensor& get_layer_params() { return params; }
+
+        friend void serialize(const tril_& item, std::ostream& out)
+        {
+            serialize("tril_", out);
+            serialize(item.diag, out);
+            serialize(item.diag_value, out);
+        }
+        friend void deserialize(tril_& item, std::istream& in)
+        {
+            std::string version;
+            deserialize(version, in);
+            if (version != "tril_")
+                throw serialization_error("Unexpected version '" + version + "' found while deserializing dlib::tril_.");
+            deserialize(item.diag, in);
+            deserialize(item.diag_value, in);
+        }
+
+        friend std::ostream& operator<<(std::ostream& out, const tril_& item)
+        {
+            out << "tril (diag=" << item.diag << ", diag_value=" << item.diag_value << ")";
+            return out;
+        }
+        friend void to_xml(const tril_& item, std::ostream& out)
+        {
+            out << "<tril diag='" << item.diag << "' diag_value='" << item.diag_value << "'/>\n";
+        }
+
+    private:
+        float compute_diag_value() const {
+            if (std::is_same<tag_, neg_infinity_tag>::value)
+                return -std::numeric_limits<float>::infinity();
+            else if (std::is_same<tag_, zero_tag>::value)
+                return 0.0f;
+            else
+                return static_cast<float>(num_) / static_cast<float>(den_);
+        }
+
+        void check_mask(const tensor& t)
+        {
+            if (!have_same_dimensions(binary_mask, t)) {
+                binary_mask.copy_size(t);
+                binary_mask = 1;
+                if (diag_value != 0.0f) {
+                    output_mask.copy_size(t);
+                    output_mask = 0;
+                }
+                // Iterate over the input dimensions rather than output_mask: output_mask is
+                // only allocated when diag_value != 0, so looping over it would skip mask
+                // construction entirely for the zero-valued case (e.g. tril_mask).
+                for (long s = 0; s < t.num_samples(); ++s)
+                {
+                    for (long k = 0; k < t.k(); ++k)
+                    {
+                        for (long r = 0; r < t.nr(); ++r)
+                        {
+                            for (long c = std::max(r + diag + 1, 0L); c < t.nc(); ++c)
+                            {
+                                if (diag_value != 0.0f) output_mask.host()[tensor_index(output_mask, s, k, r, c)] = diag_value;
+                                binary_mask.host()[tensor_index(binary_mask, s, k, r, c)] = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        template <typename T>
+        struct always_false : std::false_type {};
+
+        resizable_tensor params; // unused
+        resizable_tensor binary_mask, output_mask;
+        long diag;
+        float diag_value;
+    };
+
+    template <long diag, typename tag, typename SUBNET>
+    using tril = add_layer<tril_<diag, tag>, SUBNET>;
+
+    template <long diag, typename SUBNET>
+    using tril_mask = add_layer<tril_<diag, zero_tag>, SUBNET>;
+
+    template <long diag, long num, long den, typename SUBNET>
+    using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class positional_encodings_
     {
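Aside on the masking rule implemented by check_mask() above: an element (r, c) is kept when c <= r + diag and overwritten otherwise. The following standalone sketch is illustrative only and not part of the patch; it uses no dlib code and simply applies the same index rule to a 3x3 matrix with diag = 0 and a -inf fill value:

```cpp
#include <algorithm>
#include <iostream>
#include <limits>

int main()
{
    const long diag = 0;  // mask everything strictly above the main diagonal
    const float fill = -std::numeric_limits<float>::infinity();
    float m[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};

    // Same rule as check_mask(): every element (r, c) with c >= r + diag + 1
    // receives the fill value.
    for (long r = 0; r < 3; ++r)
        for (long c = std::max(r + diag + 1, 0L); c < 3; ++c)
            m[r][c] = fill;

    for (long r = 0; r < 3; ++r)
    {
        for (long c = 0; c < 3; ++c)
            std::cout << m[r][c] << ' ';
        std::cout << '\n';
    }
    // Output:
    // 1 -inf -inf
    // 4 5 -inf
    // 7 8 9
}
```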
diff --git a/dlib/dnn/layers_abstract.h b/dlib/dnn/layers_abstract.h
index d598c6efa0..d106b9e61f 100644
--- a/dlib/dnn/layers_abstract.h
+++ b/dlib/dnn/layers_abstract.h
@@ -3790,6 +3790,162 @@ namespace dlib
     template <typename SUBNET>
     using transpose = add_layer<transpose_, SUBNET>;
 
+// ----------------------------------------------------------------------------------------
+
+    struct neg_infinity_tag {};
+    struct zero_tag {};
+
+    template<typename T>
+    struct is_special_value : std::false_type {};
+    template<>
+    struct is_special_value<neg_infinity_tag> : std::true_type {};
+    template<>
+    struct is_special_value<zero_tag> : std::true_type {};
+
+    template<long diag_, typename tag_, long num_ = 0, long den_ = 1>
+    class tril_
+    {
+        /*!
+            TEMPLATE PARAMETERS
+                - diag_: A long integer specifying the diagonal offset.
+                - tag_: A type tag selecting a special mask value, or void for a numeric value.
+                - num_: Numerator of the numeric diagonal value (default 0, only used if tag_ is void).
+                - den_: Denominator of the numeric diagonal value (default 1, only used if tag_ is void).
+
+            REQUIREMENTS
+                - tag_ must be either neg_infinity_tag, zero_tag, or void.
+                - If tag_ is void, num_ and den_ are used to compute the diagonal value.
+                - If tag_ is neg_infinity_tag or zero_tag, num_ and den_ are ignored.
+
+            WHAT THIS OBJECT REPRESENTS
+                This object implements a layer in a deep neural network that applies a lower
+                triangular mask to its input tensor.  All elements above the specified
+                diagonal are set to a fixed value; the diagonal offset and the mask value
+                are determined by the template parameters.
+
+            DIAGONAL VALUE DETERMINATION
+                - If tag_ is neg_infinity_tag: the diagonal value is negative infinity.
+                - If tag_ is zero_tag: the diagonal value is zero.
+                - If tag_ is void: the diagonal value is num_/den_ as a float.
+
+            DIAGONAL OFFSET
+                The diag_ parameter determines the diagonal above which elements are masked:
+                - diag_ = 0: the main diagonal
+                - diag_ > 0: diag_ steps above the main diagonal
+                - diag_ < 0: |diag_| steps below the main diagonal
+
+            EXAMPLE USAGE
+                // Create a layer that masks all elements above the main diagonal with -inf
+                tril_<0, neg_infinity_tag> layer1;
+
+                // Create a layer that masks all elements above the main diagonal with 0
+                tril_<0, zero_tag> layer2;
+
+                // Create a layer that masks all elements above the main diagonal with 0.5
+                tril_<0, void, 1, 2> layer3;
+
+                // Create a layer that masks all elements 5 positions above the main diagonal with -inf
+                tril_<5, neg_infinity_tag> layer4;
+
+                // Create a layer that masks all elements 3 positions below the main diagonal with 0.25
+                tril_<-3, void, 1, 4> layer5;
+
+            SERIALIZATION SUPPORT
+                This object supports serialization and deserialization via the serialize()
+                and deserialize() functions.
+        !*/
+
+    public:
+        tril_() = default;
+        /*!
+            ensures
+                - This object is properly initialized.
+        !*/
+
+        template <typename SUBNET>
+        void setup(const SUBNET& sub);
+        /*!
+            requires
+                - SUBNET is a valid network layer type.
+            ensures
+                - Performs any required setup.  The mask itself is built lazily during
+                  forward(), based on the dimensions of the input tensor from sub.
+        !*/
+
+        template <typename SUBNET>
+        void forward(const SUBNET& sub, resizable_tensor& output);
+        /*!
+            requires
+                - SUBNET is a valid network layer type.
+            ensures
+                - Applies the lower triangular mask to the input tensor from sub and
+                  stores the result in output.
+        !*/
+
+        template <typename SUBNET>
+        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad);
+        /*!
+            requires
+                - SUBNET is a valid network layer type.
+            ensures
+                - Computes the gradient of the loss with respect to the input tensor and
+                  stores it in sub.
+        !*/
+
+        inline dpoint map_input_to_output(const dpoint& p) const;
+        /*!
+            ensures
+                - Maps a point in the input tensor to the corresponding point in the output tensor.
+        !*/
+
+        inline dpoint map_output_to_input(const dpoint& p) const;
+        /*!
+            ensures
+                - Maps a point in the output tensor to the corresponding point in the input tensor.
+        !*/
+
+        const tensor& get_layer_params() const;
+        /*!
+            ensures
+                - Returns the parameters of this layer.
+        !*/
+
+        tensor& get_layer_params();
+        /*!
+            ensures
+                - Returns the parameters of this layer.
+        !*/
+
+        friend void serialize(const tril_& item, std::ostream& out);
+        /*!
+            ensures
+                - Serializes the state of this object to the given output stream.
+        !*/
+
+        friend void deserialize(tril_& item, std::istream& in);
+        /*!
+            ensures
+                - Deserializes the state of this object from the given input stream.
+        !*/
+
+        friend std::ostream& operator<<(std::ostream& out, const tril_& item);
+        /*!
+            ensures
+                - Prints a human-readable representation of this object to the given output stream.
+        !*/
+
+        friend void to_xml(const tril_& item, std::ostream& out);
+        /*!
+            ensures
+                - Writes the state of this object to the given output stream in XML format.
+        !*/
+    };
+
+    template <long diag, typename tag, typename SUBNET>
+    using tril = add_layer<tril_<diag, tag>, SUBNET>;
+
+    template <long diag, typename SUBNET>
+    using tril_mask = add_layer<tril_<diag, zero_tag>, SUBNET>;
+
+    template <long diag, long num, long den, typename SUBNET>
+    using tril_diag = add_layer<tril_<diag, void, num, den>, SUBNET>;
+
 // ----------------------------------------------------------------------------------------
 
     class positional_encodings_
diff --git a/dlib/dnn/visitors.h b/dlib/dnn/visitors.h
index 9043a57e14..7834d103e0 100644
--- a/dlib/dnn/visitors.h
+++ b/dlib/dnn/visitors.h
@@ -1039,6 +1039,22 @@ namespace dlib
             update(i);
         }
 
+        template <long diag, typename tag, long num, long den, typename U, typename E>
+        void operator()(size_t i, const add_layer<tril_<diag, tag, num, den>, U, E>&)
+        {
+            start_node(i, "tril");
+            out << " | {diag|{" << diag << "}}";
+            out << " | {diag_value|{";
+
+            if (std::is_same<tag, neg_infinity_tag>::value) out << "-inf";
+            else if (std::is_same<tag, zero_tag>::value) out << "0";
+            else out << static_cast<float>(num) / static_cast<float>(den);
+
+            out << "}}";
+            end_node();
+            update(i);
+        }
+
         template <typename T, typename U, typename E>
         void operator()(size_t i, const add_layer<T, U, E>&)
         {
diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp
index 6ae4328760..f13a4b31f6 100644
--- a/dlib/test/dnn.cpp
+++ b/dlib/test/dnn.cpp
@@ -2056,6 +2056,12 @@ void test_positional_encodings()
         auto res = test_layer(l);
         DLIB_TEST_MSG(res, res);
     }
+    {
+        print_spinner();
+        tril_<-5, void, 1, 2> l;
+        auto res = test_layer(l);
+        DLIB_TEST_MSG(res, res);
+    }
     {
         print_spinner();
         extract_<0,2,2,2> l;
         auto res = test_layer(l);
         DLIB_TEST_MSG(res, res);
     }
@@ -4533,6 +4539,47 @@ void test_multm_prev()
     }
 }
 
+// ----------------------------------------------------------------------------------------
+
+    void test_tril()
+    {
+        print_spinner();
+        using net_type = tag1<tril<0, neg_infinity_tag, tag2<input<matrix<float>>>>>;
+        net_type net;
+
+        // Input tensor
+        dlib::rand rnd;
+        const int nr = 2, nc = 3;
+        constexpr int n_samples = 3, k = 1;
+        std::vector<matrix<float>> x(n_samples);
+        matrix<float> xtmp(nr, nc);
+        for (int ii = 0; ii < n_samples; ++ii) {
+            for (int jj = 0; jj < nr; ++jj)
+                for (int kk = 0; kk < nc; ++kk)
+                    xtmp(jj, kk) = rnd.get_random_gaussian();
+            x[ii] = xtmp;
+        }
+
+        // Convert the input matrices to a tensor
+        resizable_tensor input_tensor;
+        net.to_tensor(&x[0], &x[0] + n_samples, input_tensor);
+        net.forward(input_tensor);
+
+        // Expected output tensor (manually set for comparison)
+        resizable_tensor expected_output;
+        expected_output.copy_size(input_tensor);
+        tt::copy_tensor(false, expected_output, 0, input_tensor, 0, input_tensor.k());
+        for (int ii = 0; ii < n_samples; ++ii) {
+            expected_output.host()[tensor_index(expected_output, ii, 0, 0, 1)] = -std::numeric_limits<float>::infinity();
+            expected_output.host()[tensor_index(expected_output, ii, 0, 0, 2)] = -std::numeric_limits<float>::infinity();
+            expected_output.host()[tensor_index(expected_output, ii, 0, 1, 2)] = -std::numeric_limits<float>::infinity();
+        }
+
+        // Compare the output tensor with the expected output
+        auto& net_output = layer<tag1>(net).get_output();
+        DLIB_TEST(max(abs(mat(net_output) - mat(expected_output))) < 1e-5);
+    }
+
 // ----------------------------------------------------------------------------------------
 
 class dnn_tester : public tester
@@ -4613,6 +4660,7 @@ void test_multm_prev()
             test_layer_normalize();
             test_rms_normalize();
             test_transpose();
+            test_tril();
             test_positional_encodings();
             test_basic_tensor_ops();
             test_layers();
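For an end-to-end feel of the new aliases, here is a hypothetical usage sketch, not part of the patch; the network and variable names are placeholders. It mirrors the to_tensor/forward pattern used in test_tril() above: one 4x4 matrix of ones is pushed through tril<0, neg_infinity_tag, ...> and the masked output is printed.

```cpp
#include <dlib/dnn.h>
#include <iostream>
#include <vector>

using namespace dlib;

int main()
{
    // Causal mask: entries strictly above the main diagonal become -inf.
    using net_type = tril<0, neg_infinity_tag, input<matrix<float>>>;
    net_type net;

    std::vector<matrix<float>> x(1, matrix<float>(4, 4));
    x[0] = 1;  // fill the single sample with ones

    resizable_tensor input_tensor;
    net.to_tensor(&x[0], &x[0] + 1, input_tensor);
    net.forward(input_tensor);

    // mat() views the output tensor as a num_samples x (k*nr*nc) matrix, so this
    // prints one row of 16 values: 1 -inf -inf -inf 1 1 -inf -inf 1 1 1 -inf 1 1 1 1
    std::cout << mat(net.get_output()) << std::endl;
}
```

The same network with tril_mask<0, ...> would zero those entries instead, and tril_diag<0, 1, 2, ...> would write 0.5 there, per the alias definitions above.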