Add MultiHeadAttention layer (#392)
Dobiasd authored Dec 31, 2023
1 parent 71b6f7c commit 7104dd0
Showing 5 changed files with 228 additions and 2 deletions.
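For orientation, here is a minimal sketch of how a converted model containing the new layer would be used from C++, following the usual frugally-deep workflow (fdeep::load_model plus model.predict). The file name, input shapes, and fill values below are invented for illustration; a real model exported with keras_export/convert_model.py defines its own inputs.

#include <fdeep/fdeep.hpp>

#include <iostream>

int main()
{
    // Hypothetical Keras model with a MultiHeadAttention layer and two inputs
    // (query and value), converted to JSON via keras_export/convert_model.py.
    const auto model = fdeep::load_model("mha_model.json");

    // Query of shape (3, 4) and value of shape (5, 4), filled with constants;
    // the shapes have to match whatever the converted model actually expects.
    const fdeep::tensor query(fdeep::tensor_shape(3, 4), 0.5f);
    const fdeep::tensor value(fdeep::tensor_shape(5, 4), 0.25f);

    const auto result = model.predict({query, value});
    std::cout << fdeep::show_tensors(result) << std::endl;
    return 0;
}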
4 changes: 2 additions & 2 deletions README.md
@@ -57,7 +57,7 @@ Would you like to build/train a model using Keras/Python? And would you like to
* `UpSampling1D/2D`, `Resizing`
* `Reshape`, `Permute`, `RepeatVector`
* `Embedding`, `CategoryEncoding`
* `Attention`, `AdditiveAttention`
* `Attention`, `AdditiveAttention`, `MultiHeadAttention`


### Also supported
@@ -78,7 +78,7 @@ Would you like to build/train a model using Keras/Python? And would you like to
`GRUCell`, `Hashing`,
`IntegerLookup`,
`LocallyConnected1D`, `LocallyConnected2D`,
`LSTMCell`, `Masking`, `MultiHeadAttention`,
`LSTMCell`, `Masking`,
`RepeatVector`, `RNN`, `SimpleRNN`,
`SimpleRNNCell`, `StackedRNNCells`, `StringLookup`, `TextVectorization`,
`ThresholdedReLU`, `Upsampling3D`, `temporal` models
50 changes: 50 additions & 0 deletions include/fdeep/import_model.hpp
@@ -64,6 +64,7 @@
#include "fdeep/layers/maximum_layer.hpp"
#include "fdeep/layers/minimum_layer.hpp"
#include "fdeep/layers/model_layer.hpp"
#include "fdeep/layers/multi_head_attention_layer.hpp"
#include "fdeep/layers/multiply_layer.hpp"
#include "fdeep/layers/normalization_layer.hpp"
#include "fdeep/layers/pooling_3d_layer.hpp"
@@ -1068,6 +1069,30 @@ inline layer_ptr create_additive_attention_layer(
return std::make_shared<additive_attention_layer>(name, scale);
}

inline layer_ptr create_multi_head_attention_layer(
const get_param_f& get_param,
const nlohmann::json& data, const std::string& name)
{
const std::size_t num_heads = data["config"]["num_heads"];
const std::size_t key_dim = data["config"]["key_dim"];
const std::size_t value_dim = data["config"]["value_dim"];
const bool use_bias = data["config"]["use_bias"];
const auto weight_shapes =
create_vector<std::vector<std::size_t>>(fplus::bind_1st_of_2(
create_vector<std::size_t, decltype(create_size_t)>, create_size_t),
get_param(name, "weight_shapes"));
const auto weight_values = create_vector<float_vec>(decode_floats, get_param(name, "weights"));
const auto weights_and_biases = fplus::zip_with(
[](const std::vector<std::size_t>& shape, const float_vec& values) -> tensor
{
return tensor(
create_tensor_shape_from_dims(shape),
fplus::convert_container<float_vec>(values));
}, weight_shapes, weight_values);
return std::make_shared<multi_head_attention_layer>(name,
num_heads, key_dim, value_dim, use_bias, weights_and_biases);
}

inline std::string get_activation_type(const nlohmann::json& data)
{
assertion(data.is_string(), "Layer activation must be a string.");
@@ -1141,11 +1166,35 @@ inline node create_node(const nlohmann::json& inbound_nodes_data)
inbound_nodes_data));
}

inline nodes create_multi_head_attention_nodes(const std::vector<nlohmann::json> inbound_nodes_data)
{
assertion(inbound_nodes_data.size() == 1 && inbound_nodes_data.front().size() == 1,
"multi_head_attention needs to have exactly one primary inbound node; see https://stackoverflow.com/q/77400589/1866775");
const auto inbound_node_data = inbound_nodes_data.front().front();
const auto value = inbound_node_data[3]["value"];
if (json_obj_has_member(inbound_node_data[3], "key")) {
return {
node({
create_node_connection(inbound_node_data),
create_node_connection(value),
create_node_connection(inbound_node_data[3]["key"])
})};
}
return {
node({
create_node_connection(inbound_node_data),
create_node_connection(value)
})};
}

inline nodes create_nodes(const nlohmann::json& data)
{
assertion(data["inbound_nodes"].is_array(), "no inbound nodes");
const std::vector<nlohmann::json> inbound_nodes_data =
data["inbound_nodes"];
if (data["class_name"] == "MultiHeadAttention") {
return create_multi_head_attention_nodes(inbound_nodes_data);
}
return fplus::transform(create_node, inbound_nodes_data);
}

@@ -1378,6 +1427,7 @@ inline layer_ptr create_layer(const get_param_f& get_param,
{"CategoryEncoding", create_category_encoding_layer},
{"Attention", create_attention_layer},
{"AdditiveAttention", create_additive_attention_layer},
{"MultiHeadAttention", create_multi_head_attention_layer},
};

const wrapper_layer_creators wrapper_creators = {
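The core of create_multi_head_attention_layer is the pairing step: each entry of the exported "weight_shapes" list is zipped with the corresponding flat float buffer from "weights" before being turned into a tensor. Below is a standalone sketch of that step, using FunctionalPlus directly with made-up shapes and values, and skipping the library helpers such as decode_floats and create_tensor_shape_from_dims.

#include <fplus/fplus.hpp>

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // Hypothetical export for a single head with key_dim = 2 and use_bias = true:
    // one 3x1x2 kernel and one 1x2 bias, each stored as a flat list of floats.
    const std::vector<std::vector<std::size_t>> weight_shapes = {{3, 1, 2}, {1, 2}};
    const std::vector<std::vector<float>> weight_values = {
        {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f},
        {0.0f, 0.0f}};

    // Pair each shape with its flat buffer, just like the fplus::zip_with call
    // above does before constructing the actual fdeep tensors.
    const auto paired = fplus::zip_with(
        [](const std::vector<std::size_t>& shape, const std::vector<float>& values) -> std::string
        {
            return fplus::show_cont(shape) + " holds " + fplus::show(values.size()) + " values";
        },
        weight_shapes, weight_values);

    for (const auto& line : paired)
        std::cout << line << "\n";
    return 0;
}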
128 changes: 128 additions & 0 deletions include/fdeep/layers/multi_head_attention_layer.hpp
@@ -0,0 +1,128 @@
// Copyright 2016, Tobias Hermann.
// https://github.com/Dobiasd/frugally-deep
// Distributed under the MIT License.
// (See accompanying LICENSE file or at
// https://opensource.org/licenses/MIT)

#pragma once

#include "fdeep/layers/layer.hpp"
#include "fdeep/layers/dense_layer.hpp"
#include "fdeep/layers/softmax_layer.hpp"

#include <string>

namespace fdeep { namespace internal
{

class multi_head_attention_layer : public layer
{
public:
explicit multi_head_attention_layer(const std::string& name,
std::size_t num_heads, std::size_t key_dim, std::size_t value_dim,
bool use_bias, const std::vector<tensor>& weights_and_biases)
: layer(name), num_heads_(num_heads), key_dim_(key_dim),
value_dim_(value_dim),
query_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 0, key_dim, name + "_query_dense")),
value_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 2, value_dim, name + "_value_dense")),
key_dense_(create_dense_layers(weights_and_biases, use_bias, num_heads, 1, key_dim, name + "_key_dense")),
output_dense_(create_output_dense_layer(weights_and_biases, use_bias, name + "_output_dense"))
{
}
private:
std::vector<dense_layer> create_dense_layers(
const tensors& weights_and_biases, bool use_bias, const std::size_t num_heads,
const std::size_t index, const std::size_t units, const std::string& name)
{
assertion(index <= 2, "Invalid dense layer index.");
const std::size_t index_factor = use_bias ? 2 : 1;
const tensor weights = weights_and_biases[index_factor * index];
const tensor biases = use_bias ?
weights_and_biases[index_factor * index + 1] :
tensor(tensor_shape(num_heads, units), 0);

assertion(weights.shape().depth_ == units, "Invalid weights shape for attention head dimension.");
assertion(biases.shape().depth_ == units, "Invalid biases shape for attention head dimension.");

const auto weights_per_head = tensor_to_tensors_width_slices(weights);
const auto biases_per_head = tensor_to_tensors_width_slices(biases);
assertion(weights_per_head.size() == num_heads, "Invalid weights for number of heads.");
assertion(biases_per_head.size() == num_heads, "Invalid biases for number of heads.");
return fplus::transform(
[&](const std::pair<std::size_t, std::pair<tensor, tensor>>& n_and_w_with_b)
{
return dense_layer(
name + "_" + std::to_string(n_and_w_with_b.first),
units,
*n_and_w_with_b.second.first.as_vector(),
*n_and_w_with_b.second.second.as_vector());
},
fplus::enumerate(fplus::zip(weights_per_head, biases_per_head)));
}
dense_layer create_output_dense_layer(
const tensors& weights_and_biases, bool use_bias, const std::string& name)
{
const std::size_t index_factor = use_bias ? 2 : 1;
const tensor weights = weights_and_biases[index_factor * 3];
const std::size_t units = weights.shape().depth_;
const tensor biases = use_bias ?
weights_and_biases[index_factor * 3 + 1] :
tensor(tensor_shape(units), 0);
return dense_layer(name + "_output", units, *weights.as_vector(), *biases.as_vector());
}
tensors extract_biases(const tensors& saved_weights, bool use_bias)
{
return use_bias ? fplus::unweave(saved_weights).second : tensors();
}
tensor apply_head(
const tensor& query_raw,
const tensor& value_raw,
const tensor& key_raw,
std::size_t head_index) const
{
assertion(
query_raw.shape().rank() == 2 &&
value_raw.shape().rank() == 2 &&
key_raw.shape().rank() == 2 &&
query_raw.shape().depth_ == value_raw.shape().depth_ &&
query_raw.shape().depth_ == key_raw.shape().depth_ &&
value_raw.shape().width_ == key_raw.shape().width_,
"Invalid shapes; need a query tensor of shape (B, T, dim) and a value/key tensor of shape (B, S, dim)."
);
const tensor query = query_dense_[head_index].apply({query_raw}).front();
const tensor value = value_dense_[head_index].apply({value_raw}).front();
const tensor key = key_dense_[head_index].apply({key_raw}).front();

// https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
// https://dmol.pub/dl/attention.html#multi-head-attention-block
// https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/attention/multi_head_attention.py
// https://gist.github.com/sevagh/b71d253a347a9b59c026580625452fc5
const tensor scores = dot_product_tensors(query, transpose(key), std::vector<int>({2, 1}), false);
const std::size_t query_size = query.shape().depth_;
const tensor distribution = softmax(transform_tensor(fplus::multiply_with(1 / std::sqrt(query_size)), scores));
return dot_product_tensors(distribution, value, std::vector<int>({2, 1}), false);
}
protected:
tensors apply_impl(const tensors& input) const override
{
assertion(input.size() == 2 || input.size() == 3, "Invalid number of inputs for MultiHeadAttention layer.");
const tensor query_raw = input[0];
const tensor value_raw = input[1];
const tensor key_raw = input.size() > 2 ? input[2] : value_raw;
const auto outputs = fplus::transform([&](const std::size_t head_idx)
{
return apply_head(query_raw, value_raw, key_raw, head_idx);
}, fplus::numbers<std::size_t>(0, num_heads_));
const tensor merged = concatenate_tensors_depth(outputs);
return output_dense_.apply({merged});
}
std::size_t num_heads_;
std::size_t key_dim_;
std::size_t value_dim_;
std::vector<dense_layer> query_dense_;
std::vector<dense_layer> value_dense_;
std::vector<dense_layer> key_dense_;
dense_layer output_dense_;
};

} } // namespace fdeep, namespace internal
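Per head, apply_head above performs standard scaled dot-product attention: after the query/key/value dense projections it computes softmax(Q * K^T / sqrt(d_k)) and multiplies that distribution with V. The following self-contained sketch reproduces just that math on plain std::vector matrices, with the projections omitted and all shapes and values invented, for readers who want to check the computation without the fdeep tensor machinery.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Naive matrix product: a is (n, k), b is (k, m).
Matrix matmul(const Matrix& a, const Matrix& b)
{
    Matrix out(a.size(), std::vector<float>(b[0].size(), 0.0f));
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < b[0].size(); ++j)
            for (std::size_t k = 0; k < b.size(); ++k)
                out[i][j] += a[i][k] * b[k][j];
    return out;
}

Matrix transpose(const Matrix& a)
{
    Matrix out(a[0].size(), std::vector<float>(a.size(), 0.0f));
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = 0; j < a[0].size(); ++j)
            out[j][i] = a[i][j];
    return out;
}

// Row-wise softmax, matching the softmax over the attention scores in apply_head.
void softmax_rows(Matrix& m)
{
    for (auto& row : m)
    {
        float sum = 0.0f;
        for (auto& x : row) { x = std::exp(x); sum += x; }
        for (auto& x : row) x /= sum;
    }
}

int main()
{
    const Matrix query = {{1.0f, 0.0f}, {0.0f, 1.0f}, {1.0f, 1.0f}}; // (T = 3, d_k = 2)
    const Matrix key = {{1.0f, 0.0f}, {0.0f, 2.0f}};                 // (S = 2, d_k = 2)
    const Matrix value = {{1.0f}, {2.0f}};                           // (S = 2, d_v = 1)

    Matrix scores = matmul(query, transpose(key)); // (T, S)
    const float scale = 1.0f / std::sqrt(static_cast<float>(key[0].size()));
    for (auto& row : scores)
        for (auto& x : row)
            x *= scale;
    softmax_rows(scores);

    const Matrix head_output = matmul(scores, value); // (T, d_v), one attention head
    for (const auto& row : head_output)
        std::cout << row[0] << "\n";
    return 0;
}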
13 changes: 13 additions & 0 deletions keras_export/convert_model.py
@@ -561,6 +561,18 @@ def show_additive_attention_layer(layer):
return data


def show_multi_head_attention_layer(layer):
"""Serialize MultiHeadAttention layer to dict"""
assert len(layer.input_shape) == 3
assert layer.input_shape[0] is None
assert layer._output_shape is None
assert layer._attention_axes == (1,), "MultiHeadAttention supported only with attention_axes=None"
return {
'weight_shapes': list(map(lambda w: list(w.shape), layer.weights)),
'weights': list(map(lambda w: encode_floats(w.numpy()), layer.weights)),
}


def get_layer_functions_dict():
return {
'Conv1D': show_conv_1d_layer,
Expand Down Expand Up @@ -588,6 +600,7 @@ def get_layer_functions_dict():
'CategoryEncoding': show_category_encoding_layer,
'Attention': show_attention_layer,
'AdditiveAttention': show_additive_attention_layer,
'MultiHeadAttention': show_multi_head_attention_layer,
}


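show_multi_head_attention_layer exports only layer.weights and their shapes; the C++ importer relies on their order (query, key, value, output projection, matching the index arguments 0, 1, 2, 3 used with create_dense_layers above). For orientation, here is a hedged sketch of the shapes one would expect in "weight_shapes" for one hypothetical configuration, assuming the usual Keras kernel layout; treat the exact shapes as an assumption, not something this commit specifies.

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical configuration; the shape layout below is an assumption about
    // how Keras arranges MultiHeadAttention weights, not defined by this commit.
    const std::size_t num_heads = 2;
    const std::size_t key_dim = 3;
    const std::size_t value_dim = 5;
    const std::size_t feature_dim = 4; // assumed last dimension of the query/value/key inputs

    // Expected order with use_bias=True: kernel/bias pairs for the query, key,
    // and value projections, followed by the output projection.
    std::cout << "query  kernel: (" << feature_dim << ", " << num_heads << ", " << key_dim << ")"
              << "  bias: (" << num_heads << ", " << key_dim << ")\n";
    std::cout << "key    kernel: (" << feature_dim << ", " << num_heads << ", " << key_dim << ")"
              << "  bias: (" << num_heads << ", " << key_dim << ")\n";
    std::cout << "value  kernel: (" << feature_dim << ", " << num_heads << ", " << value_dim << ")"
              << "  bias: (" << num_heads << ", " << value_dim << ")\n";
    std::cout << "output kernel: (" << num_heads << ", " << value_dim << ", " << feature_dim << ")"
              << "  bias: (" << feature_dim << ",)\n";
    return 0;
}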
35 changes: 35 additions & 0 deletions keras_export/generate_test_models.py
@@ -23,6 +23,7 @@
from tensorflow.keras.layers import MaxPooling1D, AveragePooling1D, UpSampling1D
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, UpSampling2D
from tensorflow.keras.layers import MaxPooling3D, AveragePooling3D
from tensorflow.keras.layers import MultiHeadAttention
from tensorflow.keras.layers import Multiply, Add, Subtract, Average, Maximum, Minimum, Dot
from tensorflow.keras.layers import Permute, Reshape, RepeatVector
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
@@ -435,6 +436,40 @@ def get_test_model_exhaustive():
outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50]]))
outputs.append(AdditiveAttention(use_scale=True)([inputs[49], inputs[50], inputs[51]]))

outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=2, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=2, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=2,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=2,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=3, key_dim=1, value_dim=None,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50]))
outputs.append(MultiHeadAttention(
num_heads=1, key_dim=1, value_dim=None,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
outputs.append(MultiHeadAttention(
num_heads=2, key_dim=3, value_dim=5,
use_bias=False, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))
outputs.append(MultiHeadAttention(
num_heads=2, key_dim=3, value_dim=5,
use_bias=True, output_shape=None, attention_axes=None)(inputs[49], inputs[50], inputs[51]))

shared_conv = Conv2D(1, (1, 1),
padding='valid', name='shared_conv', activation='relu')
