Skip to content

Commit

Permalink
Merge pull request #355 from darioizzo/ffnn
Browse files Browse the repository at this point in the history
ffnn first impl
  • Loading branch information
bluescarni authored Nov 3, 2023
2 parents ec51518 + b93e68c commit 5ce6a9f
Show file tree
Hide file tree
Showing 7 changed files with 414 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ set(HEYOKA_SRC_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/mascon.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/vsop2013.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/cr3bp.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/model/ffnn.cpp"
# Math functions.
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/kepE.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/math/kepF.cpp"
Expand Down
7 changes: 7 additions & 0 deletions include/heyoka/kw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ IGOR_MAKE_NAMED_ARGUMENT(parallel_mode);
IGOR_MAKE_NAMED_ARGUMENT(prec);
IGOR_MAKE_NAMED_ARGUMENT(mu);

// kwargs for the ffnn
IGOR_MAKE_NAMED_ARGUMENT(inputs);      // mandatory: range of network input expressions
IGOR_MAKE_NAMED_ARGUMENT(nn_hidden);   // mandatory: range of neuron counts, one per hidden layer
IGOR_MAKE_NAMED_ARGUMENT(n_out);       // mandatory: integral number of network outputs
IGOR_MAKE_NAMED_ARGUMENT(activations); // mandatory: range of activation functions, one per non-input layer
IGOR_MAKE_NAMED_ARGUMENT(nn_wb);       // optional: flattened weights/biases (defaults to heyoka pars)

} // namespace kw

HEYOKA_END_NAMESPACE
Expand Down
155 changes: 155 additions & 0 deletions include/heyoka/model/ffnn.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2020, 2021, 2022, 2023 Francesco Biscani ([email protected]), Dario Izzo ([email protected])
//
// This file is part of the heyoka library.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef HEYOKA_MODEL_FFNN_HPP
#define HEYOKA_MODEL_FFNN_HPP

#include <cstdint>
#include <functional>
#include <tuple>
#include <utility>
#include <vector>

#include <boost/numeric/conversion/cast.hpp>
#include <boost/safe_numerics/safe_integer.hpp>

#include <heyoka/config.hpp>
#include <heyoka/detail/igor.hpp>
#include <heyoka/detail/type_traits.hpp>
#include <heyoka/detail/visibility.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/kw.hpp>

HEYOKA_BEGIN_NAMESPACE

namespace model
{
namespace detail
{

// Parse and validate the keyword arguments accepted by the ffnn factory.
//
// Mandatory kwargs: kw::inputs, kw::nn_hidden, kw::n_out, kw::activations.
// Optional kwarg: kw::nn_wb (defaults to the heyoka runtime parameters
// par[0], par[1], ...).
//
// Returns a tuple (inputs, nn_hidden, n_out, activations, nn_wb) suitable
// for std::apply-ing onto ffnn_impl().
template <typename... KwArgs>
auto ffnn_common_opts(const KwArgs &...kw_args)
{
    igor::parser p{kw_args...};

    // Only keyword (named) arguments are accepted.
    static_assert(!p.has_unnamed_arguments(), "This function accepts only named arguments");

    // Network inputs. Mandatory.
    // The kw::inputs argument must be a range of values from which
    // an expression can be constructed.
    std::vector<expression> inputs;
    if constexpr (p.has(kw::inputs)) {
        for (const auto &val : p(kw::inputs)) {
            inputs.emplace_back(val);
        }
    } else {
        // NOTE: always_false_v defers the static_assert so that it fires only
        // when this branch is actually instantiated (i.e., kwarg missing).
        static_assert(heyoka::detail::always_false_v<KwArgs...>,
                      "The 'inputs' keyword argument is necessary but it was not provided");
    }

    // Number of hidden neurons per hidden layer. Mandatory.
    // The kw::nn_hidden argument must be a range containing
    // integral values.
    std::vector<std::uint32_t> nn_hidden;
    if constexpr (p.has(kw::nn_hidden)) {
        for (const auto &nval : p(kw::nn_hidden)) {
            // numeric_cast throws on overflow/negative values.
            nn_hidden.push_back(boost::numeric_cast<std::uint32_t>(nval));
        }
    } else {
        static_assert(heyoka::detail::always_false_v<KwArgs...>,
                      "The 'nn_hidden' keyword argument is necessary but it was not provided");
    }

    // Number of network outputs. Mandatory.
    // The kw::n_out argument must be of integral type.
    auto n_out = [&p]() {
        if constexpr (p.has(kw::n_out)) {
            return boost::numeric_cast<std::uint32_t>(p(kw::n_out));
        } else {
            static_assert(heyoka::detail::always_false_v<KwArgs...>,
                          "The 'n_out' keyword argument is necessary but it was not provided");
        }
    }();

    // Network activation functions. Mandatory.
    // The kw::activations argument must be a range containing values
    // from which a std::function can be constructed.
    std::vector<std::function<expression(const expression &)>> activations;
    if constexpr (p.has(kw::activations)) {
        for (const auto &f : p(kw::activations)) {
            activations.emplace_back(f);
        }
    } else {
        static_assert(heyoka::detail::always_false_v<KwArgs...>,
                      "The 'activations' keyword argument is necessary but it was not provided");
    }

    // Network weights and biases. Optional, defaults to heyoka parameters.
    // The kw::nn_wb argument, if present, must be a range of values from which
    // expressions can be constructed.
    std::vector<expression> nn_wb;
    if constexpr (p.has(kw::nn_wb)) {
        for (const auto &val : p(kw::nn_wb)) {
            nn_wb.emplace_back(val);
        }
    } else {
        // Safe counterpart to std::uint32_t in order to avoid
        // overflows when manipulating indices and sizes.
        using su32 = boost::safe_numerics::safe<std::uint32_t>;

        // Number of hidden layers (i.e., all neuronal columns that are
        // neither input nor output neurons).
        auto n_hidden_layers = su32(nn_hidden.size());
        // Number of neuronal layers (counting input and output).
        auto n_layers = n_hidden_layers + 2;
        // Number of inputs.
        auto n_in = su32(inputs.size());
        // Number of neurons per neuronal layer: [n_in, hidden..., n_out].
        std::vector<su32> n_neurons{n_in};
        n_neurons.insert(n_neurons.end(), nn_hidden.begin(), nn_hidden.end());
        n_neurons.insert(n_neurons.end(), n_out);

        // Number of network parameters (wb: weights and biases, w: only weights).
        // Each non-input layer i contributes a weight matrix of size
        // n_neurons[i-1]*n_neurons[i] plus a bias vector of size n_neurons[i].
        su32 n_wb = 0;
        for (su32 i = 1; i < n_layers; ++i) {
            n_wb += n_neurons[i - 1] * n_neurons[i];
            n_wb += n_neurons[i];
        }
        // Default the parameters to par[0] ... par[n_wb - 1].
        nn_wb.resize(n_wb);
        for (decltype(nn_wb.size()) i = 0; i < nn_wb.size(); ++i) {
            nn_wb[i] = par[boost::numeric_cast<std::uint32_t>(i)];
        }
    }

    return std::tuple{std::move(inputs), std::move(nn_hidden), std::move(n_out), std::move(activations),
                      std::move(nn_wb)};
}

// This C++ function returns the symbolic expressions of the `n_out` output neurons in a feed-forward neural network,
// as a function of the `n_in` input expressions.
//
// The expressions will contain the weights and biases of the neural network flattened into `pars` with the following
// convention:
//
// from the left to the right layer of parameters: [W01, W12, W23, ..., B1, B2, B3, ...], where the weight matrices
// Wij are to be considered as flattened (row first) and so are the bias vectors.
//
HEYOKA_DLL_PUBLIC std::vector<expression> ffnn_impl(const std::vector<expression> &, const std::vector<std::uint32_t> &,
                                                    std::uint32_t,
                                                    const std::vector<std::function<expression(const expression &)>> &,
                                                    const std::vector<expression> &);
} // namespace detail

// Factory for the symbolic outputs of a feed-forward neural network.
// Accepts only keyword arguments (see detail::ffnn_common_opts()).
inline constexpr auto ffnn = [](const auto &...kw_args) -> std::vector<expression> {
    // Parse/validate the kwargs, then forward them to the implementation.
    auto [inputs, nn_hidden, n_out, activations, nn_wb] = detail::ffnn_common_opts(kw_args...);

    return detail::ffnn_impl(inputs, nn_hidden, n_out, activations, nn_wb);
};

} // namespace model

HEYOKA_END_NAMESPACE

#endif
1 change: 1 addition & 0 deletions include/heyoka/models.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define HEYOKA_MODELS_HPP

#include <heyoka/model/cr3bp.hpp>
#include <heyoka/model/ffnn.hpp>
#include <heyoka/model/fixed_centres.hpp>
#include <heyoka/model/mascon.hpp>
#include <heyoka/model/nbody.hpp>
Expand Down
135 changes: 135 additions & 0 deletions src/model/ffnn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Copyright 2020, 2021, 2022, 2023 Francesco Biscani ([email protected]), Dario Izzo ([email protected])
//
// This file is part of the heyoka library.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <vector>

#include <boost/safe_numerics/safe_integer.hpp>

#include <fmt/core.h>

#include <heyoka/config.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/model/ffnn.hpp>

HEYOKA_BEGIN_NAMESPACE

namespace model::detail
{

namespace
{

// Safe counterpart to std::uint32_t in order to avoid
// overflows when manipulating indices and sizes.
using su32 = boost::safe_numerics::safe<std::uint32_t>;

// Build the expressions for one non-input layer of the network: out = activation(W * inputs + b).
//
// layer_id is the index of the layer being computed (must be > 0), inputs are the
// expressions produced by the previous layer, n_neurons the per-layer neuron counts,
// nn_wb the flattened network parameters (all weights first, then all biases, hence
// the n_net_w offset for the bias lookup). wcounter/bcounter track the next weight
// and bias index and are advanced in place.
std::vector<expression> compute_layer(su32 layer_id, const std::vector<expression> &inputs,
                                      const std::vector<su32> &n_neurons,
                                      const std::function<expression(const expression &)> &activation,
                                      const std::vector<expression> &nn_wb, su32 n_net_w, su32 &wcounter,
                                      su32 &bcounter)
{
    assert(layer_id > 0u);

    const auto n_prev = su32(inputs.size());
    const auto n_curr = n_neurons[layer_id];

    // One expression per neuron in the current layer, starting from zero.
    std::vector<expression> out(static_cast<std::vector<expression>::size_type>(n_curr), 0_dbl);

    for (su32 neuron = 0; neuron < n_curr; ++neuron) {
        // Accumulate the weighted inputs, advancing the weight counter.
        for (su32 j = 0; j < n_prev; ++j) {
            out[neuron] += nn_wb[wcounter] * inputs[j];
            ++wcounter;
        }

        // Add the bias term (biases live after all the weights in nn_wb),
        // advancing the bias counter.
        out[neuron] += nn_wb[bcounter + n_net_w];
        ++bcounter;

        // Apply the activation function.
        out[neuron] = activation(out[neuron]);
    }

    return out;
}

} // namespace

std::vector<expression> ffnn_impl(const std::vector<expression> &in, const std::vector<std::uint32_t> &nn_hidden,
std::uint32_t n_out,
const std::vector<std::function<expression(const expression &)>> &activations,
const std::vector<expression> &nn_wb)
{
// Sanity checks.
if (activations.empty()) {
throw std::invalid_argument("Cannot create a FFNN with an empty list of activation functions");
}
if (nn_hidden.size() != activations.size() - 1u) {
throw std::invalid_argument(fmt::format(
"The number of hidden layers, as detected from the inputs, was {}, while "
"the number of activation function supplied was {}. A FFNN needs exactly one more activation function "
"than the number of hidden layers.",
nn_hidden.size(), activations.size()));
}
if (in.empty()) {
throw std::invalid_argument("The inputs provided to the FFNN is an empty vector.");
}
if (n_out == 0u) {
throw std::invalid_argument("The number of network outputs cannot be zero.");
}
if (!std::all_of(nn_hidden.begin(), nn_hidden.end(), [](auto item) { return item > 0u; })) {
throw std::invalid_argument("The number of neurons for each hidden layer must be greater than zero!");
}
if (std::any_of(activations.begin(), activations.end(), [](const auto &func) { return !func; })) {
throw std::invalid_argument("The list of activation functions cannot contain empty functions");
}

// From now on, always use safe arithmetics to compute/manipulate
// indices and sizes.
using detail::su32;

// Number of hidden layers (defined as all neuronal columns that are nor input nor output neurons).
auto n_hidden_layers = su32(nn_hidden.size());
// Number of neuronal layers (counting input and output).
auto n_layers = n_hidden_layers + 2;
// Number of inputs.
auto n_in = su32(in.size());
// Number of neurons per neuronal layer.
std::vector<su32> n_neurons{n_in};
n_neurons.insert(n_neurons.end(), nn_hidden.begin(), nn_hidden.end());
n_neurons.insert(n_neurons.end(), n_out);
// Number of network parameters (wb: weights and biases, w: only weights).
su32 n_net_wb = 0, n_net_w = 0;
for (su32 i = 1; i < n_layers; ++i) {
n_net_wb += n_neurons[i - 1u] * n_neurons[i];
n_net_w += n_neurons[i - 1u] * n_neurons[i];
n_net_wb += n_neurons[i];
}
// Sanity check.
if (nn_wb.size() != n_net_wb) {
throw std::invalid_argument(fmt::format(
"The number of network parameters, detected from its structure to be {}, does not match the size of "
"the corresponding expressions: {}.",
static_cast<std::uint32_t>(n_net_wb), nn_wb.size()));
}

// Now we build the expressions recursively transvering from layer to layer (L = f(Wx+b))).
std::vector<expression> retval = in;
su32 wcounter = 0, bcounter = 0;
for (su32 i = 1; i < n_layers; ++i) {
retval = detail::compute_layer(i, retval, n_neurons, activations[i - 1u], nn_wb, n_net_w, wcounter, bcounter);
}
return retval;
}

} // namespace model::detail

HEYOKA_END_NAMESPACE
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ ADD_HEYOKA_TESTCASE(model_fixed_centres)
ADD_HEYOKA_TESTCASE(model_rotating)
ADD_HEYOKA_TESTCASE(model_mascon)
ADD_HEYOKA_TESTCASE(model_cr3bp)
ADD_HEYOKA_TESTCASE(model_ffnn)
ADD_HEYOKA_TESTCASE(step_callback)
ADD_HEYOKA_TESTCASE(llvm_state_mem_cache)

Expand Down
Loading

0 comments on commit 5ce6a9f

Please sign in to comment.