Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#3127 Add log sum exp func #3131

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
#include <stan/math/fwd/fun/log_rising_factorial.hpp>
#include <stan/math/fwd/fun/log_softmax.hpp>
#include <stan/math/fwd/fun/log_sum_exp.hpp>
#include <stan/math/fwd/fun/log_add_exp.hpp>
#include <stan/math/fwd/fun/logit.hpp>
#include <stan/math/fwd/fun/mdivide_left.hpp>
#include <stan/math/fwd/fun/mdivide_left_ldlt.hpp>
Expand Down
162 changes: 162 additions & 0 deletions stan/math/fwd/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#ifndef STAN_MATH_FWD_FUN_LOG_ADD_EXP_HPP
#define STAN_MATH_FWD_FUN_LOG_ADD_EXP_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/log_add_exp.hpp>
#include <cmath>
#include <vector>

namespace stan {
namespace math {

/**
 * Forward-mode overload of log_add_exp for two fvar arguments.
 *
 * Returns log(exp(x1) + exp(x2)) together with the tangent
 * x1.d_ * w1 + x2.d_ * w2, where w_k = exp(x_k - value) is the partial
 * derivative of the result with respect to x_k.
 *
 * The weights are formed from the difference against the result instead of
 * exp(x_k) / (exp(x1) + exp(x2)): the quotient evaluates to inf / inf (NaN)
 * once an input is large enough for exp to overflow, and to 0 / 0 (NaN)
 * once both inputs are small enough for exp to underflow, even though the
 * weights themselves always lie in [0, 1].
 *
 * @tparam T inner type of the fvar
 * @param x1 first argument
 * @param x2 second argument
 * @return fvar holding the value and the propagated tangent
 */
template <typename T>
inline fvar<T> log_add_exp(const fvar<T>& x1, const fvar<T>& x2) {
  const auto val = stan::math::log_add_exp(x1.val_, x2.val_);

  // d/dx_k log(exp(x1) + exp(x2)) = exp(x_k - val)
  const auto weight1 = stan::math::exp(x1.val_ - val);
  const auto weight2 = stan::math::exp(x2.val_ - val);

  return fvar<T>(val, x1.d_ * weight1 + x2.d_ * weight2);
}

/**
 * Forward-mode overload of log_add_exp for an fvar first argument and a
 * double second argument.
 *
 * When x1 is negative infinity the result is x2 with a zero tangent;
 * otherwise the call is forwarded with the arguments swapped.
 *
 * @tparam T inner type of the fvar
 * @param x1 first argument (carries the tangent)
 * @param x2 second argument (constant)
 * @return fvar holding the value and the propagated tangent
 */
template <typename T>
inline fvar<T> log_add_exp(const fvar<T>& x1, double x2) {
  const bool x1_vanishes = (x1.val_ == NEGATIVE_INFTY);
  // log_add_exp(-inf, b) = b; the swapped call covers every other case.
  return x1_vanishes ? fvar<T>(x2, 0.0) : log_add_exp(x2, x1);
}

/**
 * Forward-mode overload of log_add_exp for a double first argument and an
 * fvar second argument.
 *
 * Returns log(exp(x1) + exp(x2)) with tangent x2.d_ * exp(x2 - value).
 * The weight is computed from the difference against the result rather
 * than as exp(x2) / (exp(x1) + exp(x2)), which becomes inf / inf or 0 / 0
 * (NaN) when exp overflows or underflows for large-magnitude inputs.
 *
 * @tparam T inner type of the fvar
 * @param x1 first argument (constant)
 * @param x2 second argument (carries the tangent)
 * @return fvar holding the value and the propagated tangent
 */
template <typename T>
inline fvar<T> log_add_exp(double x1, const fvar<T>& x2) {
  if (x2.val_ == NEGATIVE_INFTY) {
    return fvar<T>(x1, 0.0);  // log_add_exp(a, -inf) = a
  }
  const auto val = stan::math::log_add_exp(x1, x2.val_);
  // d/dx2 log(exp(x1) + exp(x2)) = exp(x2 - val)
  const auto weight = stan::math::exp(x2.val_ - val);
  return fvar<T>(val, x2.d_ * weight);
}

/**
 * Element-wise log_add_exp for two dynamically sized matrices of fvar.
 *
 * Every entry is delegated to the scalar fvar overload, so a non-finite
 * entry only influences its own result instead of the whole matrix.
 *
 * @tparam T inner type of the fvar
 * @param a first matrix
 * @param b second matrix, same shape as a
 * @return matrix with result(i, j) = log_add_exp(a(i, j), b(i, j))
 * @throw std::invalid_argument if either input is empty or the shapes of
 *   a and b differ
 */
template <typename T>
inline Eigen::Matrix<fvar<T>, -1, -1> log_add_exp(
    const Eigen::Matrix<fvar<T>, -1, -1>& a,
    const Eigen::Matrix<fvar<T>, -1, -1>& b) {
  // Validate before allocating anything.
  if (a.size() == 0 || b.size() == 0) {
    throw std::invalid_argument("Input containers must not be empty.");
  }
  // Without this check b(i, j) below could read out of bounds.
  if (a.rows() != b.rows() || a.cols() != b.cols()) {
    throw std::invalid_argument(
        "log_add_exp: dimensions of a and b must match.");
  }

  Eigen::Matrix<fvar<T>, -1, -1> result(a.rows(), a.cols());
  // Column-outer loop follows the default column-major storage order.
  for (Eigen::Index j = 0; j < a.cols(); ++j) {
    for (Eigen::Index i = 0; i < a.rows(); ++i) {
      result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j));
    }
  }
  return result;
}

/**
 * Element-wise log_add_exp for two dynamically sized column vectors of
 * fvar.
 *
 * Every entry is delegated to the scalar fvar overload, so a non-finite
 * entry only influences its own result instead of the whole vector.
 *
 * @tparam T inner type of the fvar
 * @param a first vector
 * @param b second vector, same length as a
 * @return vector with result(i) = log_add_exp(a(i), b(i))
 * @throw std::invalid_argument if either input is empty or the lengths of
 *   a and b differ
 */
template <typename T>
inline Eigen::Matrix<fvar<T>, -1, 1> log_add_exp(
    const Eigen::Matrix<fvar<T>, -1, 1>& a,
    const Eigen::Matrix<fvar<T>, -1, 1>& b) {
  // Validate before allocating anything.
  if (a.size() == 0 || b.size() == 0) {
    throw std::invalid_argument("Input containers must not be empty.");
  }
  // Without this check b(i) below could read out of bounds.
  if (a.size() != b.size()) {
    throw std::invalid_argument("log_add_exp: sizes of a and b must match.");
  }

  Eigen::Matrix<fvar<T>, -1, 1> result(a.size());
  for (Eigen::Index i = 0; i < a.size(); ++i) {
    result(i) = stan::math::log_add_exp(a(i), b(i));
  }
  return result;
}

/**
 * Element-wise log_add_exp for two dynamically sized matrices of
 * fvar<fvar<double>>.
 *
 * Every entry is delegated to the scalar fvar overload, so a non-finite
 * entry only influences its own result instead of the whole matrix.
 *
 * NOTE(review): T does not appear in any parameter type, so it cannot be
 * deduced; this overload is only selectable when a template argument is
 * supplied explicitly. Confirm whether it should be a non-template
 * function or be dropped in favour of the generic fvar matrix overload.
 *
 * @tparam T unused; kept so the existing signature does not change
 * @param a first matrix
 * @param b second matrix, same shape as a
 * @return matrix with result(i, j) = log_add_exp(a(i, j), b(i, j))
 * @throw std::invalid_argument if either input is empty or the shapes of
 *   a and b differ
 */
template <typename T>
inline auto log_add_exp(
    const Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>& a,
    const Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>&
        b) {
  using nested_fvar_mat_type
      = Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>;

  // Validate before allocating anything.
  if (a.size() == 0 || b.size() == 0) {
    throw std::invalid_argument("Input containers must not be empty.");
  }
  // Without this check b(i, j) below could read out of bounds.
  if (a.rows() != b.rows() || a.cols() != b.cols()) {
    throw std::invalid_argument(
        "log_add_exp: dimensions of a and b must match.");
  }

  nested_fvar_mat_type result(a.rows(), a.cols());
  // Column-outer loop follows the default column-major storage order.
  for (Eigen::Index j = 0; j < a.cols(); ++j) {
    for (Eigen::Index i = 0; i < a.rows(); ++i) {
      result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j));
    }
  }
  return result;
}

} // namespace math
} // namespace stan

#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
#include <stan/math/prim/fun/log_softmax.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp_signed.hpp>
#include <stan/math/prim/fun/log_add_exp.hpp>
#include <stan/math/prim/fun/logical_and.hpp>
#include <stan/math/prim/fun/logical_eq.hpp>
#include <stan/math/prim/fun/logical_gt.hpp>
Expand Down
159 changes: 159 additions & 0 deletions stan/math/prim/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#ifndef STAN_MATH_PRIM_FUN_LOG_ADD_EXP_HPP
#define STAN_MATH_PRIM_FUN_LOG_ADD_EXP_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <cmath>
#include <vector>
#include <algorithm>
#include <stan/math/prim/err/check_matching_dims.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>

namespace stan {
namespace math {

/**
 * Calculates the log of the sum of the exponentials of two scalars
 * without overflow.
 *
 * \f$\log (\exp(a) + \exp(b)) = m + \log(1 + \exp(-|a - b|))\f$,
 *
 * where \f$m = max(a, b)\f$.
 *
 * Special cases: if either argument is negative infinity the other
 * argument is returned; otherwise, if either argument is positive
 * infinity the result is positive infinity.
 *
 * Both arguments are promoted to the common return type before the
 * computation so that mixed argument types (for example int and double)
 * are accepted.
 *
 * @tparam T1 type of the second variable
 * @tparam T2 type of the first variable
 * @param a the first variable
 * @param b the second variable
 * @return log of the sum of the exponentials of a and b
 */
template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr,
          require_all_stan_scalar_t<T1, T2>* = nullptr>
inline return_type_t<T1, T2> log_add_exp(const T2& a, const T1& b) {
  if (a == NEGATIVE_INFTY) {
    return b;
  }
  if (b == NEGATIVE_INFTY) {
    return a;
  }
  if (a == INFTY || b == INFTY) {
    return INFTY;
  }

  // Promote once; std::max(a, b) cannot deduce a type when T1 != T2.
  const return_type_t<T1, T2> a_val = a;
  const return_type_t<T1, T2> b_val = b;

  // The log1p_exp argument is always <= 0, so exp() cannot overflow, and
  // precision is kept when the two inputs are far apart.
  if (a_val > b_val) {
    return a_val + log1p_exp(b_val - a_val);
  }
  return b_val + log1p_exp(a_val - b_val);
}

/**
* Calculates the element-wise log sum of exponentials for two containers.
* For vectors a and b, computes log(exp(a[i]) + exp(b[i])) for each element i.
* If sizes don't match, uses the smaller size.
*
* @tparam T1 type of first container
* @tparam T2 type of second container
* @param a First input container
* @param b Second input container
* @return Container with element-wise log_add_exp results
*/
template <typename T, require_container_st<std::is_arithmetic, T>* = nullptr>
inline auto log_add_exp(const T& a, const T& b) {
// Check if sizes are compatible
if constexpr (stan::is_eigen<T>::value) {
// Check if both matrices/vectors have the same dimensions
stan::math::check_matching_dims("log_add_exp", "a", a, "b", b);

// Determine the number of rows and columns for the result
size_t rows = a.rows();
size_t cols = b.cols();
using return_t = return_type_t<T>;

Eigen::Matrix<return_t, Eigen::Dynamic, Eigen::Dynamic> result(rows, cols);

// Iterate over each element
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
double a_val = (a.cols() == 1)
? a(i, 0)
: a(i, j); // Handle column vector or matrix
double b_val = (b.rows() == 1)
? b(0, j)
: b(i, j); // Handle row vector or matrix

if (a_val == NEGATIVE_INFTY) {
result(i, j) = b_val;
} else if (b_val == NEGATIVE_INFTY) {
result(i, j) = a_val;
} else if (a_val == INFTY || b_val == INFTY) {
result(i, j) = INFTY;
} else {
result(i, j) = log_sum_exp(a_val, b_val);
}
}
}

return result;
} else if constexpr (std::is_same_v<T, std::vector<typename T::value_type>>) {
// Handle std::vector
if (a.size() != b.size()) {
throw std::invalid_argument("Sizes of x and y must match.");
}

using return_t = return_type_t<T>;
std::vector<return_t> result(a.size());

for (size_t i = 0; i < a.size(); ++i) {
double a_val = a[i];
double b_val = b[i];

if (a_val == NEGATIVE_INFTY) {
result[i] = b_val;
} else if (b_val == NEGATIVE_INFTY) {
result[i] = a_val;
} else if (a_val == INFTY || b_val == INFTY) {
result[i] = INFTY;
} else {
result[i] = log_sum_exp(a_val, b_val);
}
}

return result;
} else {
throw std::invalid_argument("Unsupported container type.");
}
}

/**
 * Enables the vectorized application of the log_add_exp function,
 * when the first and/or second arguments are containers.
 *
 * Two Eigen arguments are validated with check_matching_dims; two
 * containers of any other combination must report equal size(). When only
 * one argument is a container no size comparison is made, because a
 * scalar has no size() member and the call would otherwise not compile.
 * The arguments are then forwarded to apply_scalar_binary with the scalar
 * log_add_exp as the functor.
 *
 * @tparam T1 type of the first argument (container or scalar)
 * @tparam T2 type of the second argument (container or scalar)
 * @param a first argument
 * @param b second argument
 * @return result of apply_scalar_binary applied to a and b
 * @throw std::invalid_argument if both arguments are containers and their
 *   dimensions or sizes do not match
 */
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
inline auto log_add_exp(const T1& a, const T2& b) {
  if constexpr (stan::is_eigen<T1>::value && stan::is_eigen<T2>::value) {
    // Check if both matrices/vectors have the same dimensions
    stan::math::check_matching_dims("log_add_exp", "a", a, "b", b);
  } else if constexpr (stan::is_container<T1>::value
                       && stan::is_container<T2>::value) {
    // Compare as unsigned: Eigen reports a signed size, std::vector does not.
    if (static_cast<std::size_t>(a.size())
        != static_cast<std::size_t>(b.size())) {
      throw std::invalid_argument(
          "Sizes of x and y must match or be compatible.");
    }
  }

  return apply_scalar_binary(
      a, b, [](const auto& c, const auto& d) { return log_add_exp(c, d); });
}

} // namespace math
} // namespace stan

#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
#include <stan/math/rev/fun/log_rising_factorial.hpp>
#include <stan/math/rev/fun/log_softmax.hpp>
#include <stan/math/rev/fun/log_sum_exp.hpp>
#include <stan/math/rev/fun/log_add_exp.hpp>
#include <stan/math/rev/fun/logit.hpp>
#include <stan/math/rev/fun/matrix_exp_multiply.hpp>
#include <stan/math/rev/fun/matrix_power.hpp>
Expand Down
Loading