Skip to content

Commit

Permalink
Merge pull request #454 from bluescarni/pr/pow_log_exp
Browse files Browse the repository at this point in the history
Automatic pow -> log/exp transformation
  • Loading branch information
bluescarni authored Sep 16, 2024
2 parents 937b884 + 26c888e commit 7c01a83
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 61 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Changelog
New
~~~

- Non-number exponents for the ``pow()`` function
are now supported in Taylor integrators
(`#454 <https://github.com/bluescarni/heyoka/pull/454>`__).
- It is now possible to initialise a Taylor integrator
with an empty initial state vector. This will result
in zero-initialization of the state vector
Expand Down
49 changes: 25 additions & 24 deletions src/math/pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
// Fetch the pow eval algo.
const auto pea = get_pow_eval_algo(f);

// Codegen the exponent.
auto *expo = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);

// Fetch the internal vector type.
auto *vec_t = make_vector_type(fp_t, batch_size);

if (order == 0u) {
return pea.eval_f(
s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size)});
return pea.eval_f(s, {taylor_fetch_diff(arr, u_idx, 0, n_uvars), expo});
}

// Special case for sqrt().
Expand All @@ -514,7 +519,6 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
}

// The general case.
auto &builder = s.builder();

// NOTE: iteration in the [0, order) range
// (i.e., order *not* included).
Expand All @@ -524,27 +528,14 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
auto *v1 = taylor_fetch_diff(arr, idx, j, n_uvars);

// Compute the scalar factor: order * num - j * (num + 1).
auto scal_f = [&]() -> llvm::Value * {
if constexpr (std::is_same_v<U, number>) {
return vector_splat(
builder,
llvm_codegen(s, fp_t,
number_like(s, fp_t, static_cast<double>(order)) * num
- number_like(s, fp_t, static_cast<double>(j)) * (num + number_like(s, fp_t, 1.))),
batch_size);
} else {
auto pc = taylor_codegen_numparam(s, fp_t, num, par_ptr, batch_size);
auto *jvec = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(j))), batch_size);
auto *ordvec
= vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(order))), batch_size);
auto *onevec = vector_splat(builder, llvm_codegen(s, fp_t, number(1.)), batch_size);
auto *jvec = llvm_codegen(s, vec_t, number(static_cast<double>(j)));
auto *ordvec = llvm_codegen(s, vec_t, number(static_cast<double>(order)));
auto *onevec = llvm_codegen(s, vec_t, number(1.));

auto tmp1 = llvm_fmul(s, ordvec, pc);
auto tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, pc, onevec));
auto *tmp1 = llvm_fmul(s, ordvec, expo);
auto *tmp2 = llvm_fmul(s, jvec, llvm_fadd(s, expo, onevec));

return llvm_fsub(s, tmp1, tmp2);
}
}();
auto *scal_f = llvm_fsub(s, tmp1, tmp2);

// Add scal_f*v0*v1 to the sum.
sum.push_back(llvm_fmul(s, scal_f, llvm_fmul(s, v0, v1)));
Expand All @@ -554,14 +545,16 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &s, llvm::Type *fp_t, const pow_imp
auto *ret_acc = pairwise_sum(s, sum);

// Compute the final divisor: order * (zero-th derivative of u_idx).
auto *ord_f = vector_splat(builder, llvm_codegen(s, fp_t, number(static_cast<double>(order))), batch_size);
auto *ord_f = llvm_codegen(s, vec_t, number(static_cast<double>(order)));
auto *b0 = taylor_fetch_diff(arr, u_idx, 0, n_uvars);
auto *div = llvm_fmul(s, ord_f, b0);

// Compute and return the result: ret_acc / div.
return llvm_fdiv(s, ret_acc, div);
}

// LCOV_EXCL_START

// All the other cases.
template <typename U1, typename U2, std::enable_if_t<!std::conjunction_v<is_num_param<U1>, is_num_param<U2>>, int> = 0>
llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &,
Expand All @@ -572,19 +565,23 @@ llvm::Value *taylor_diff_pow_impl(llvm_state &, llvm::Type *, const pow_impl &,
"An invalid argument type was encountered while trying to build the Taylor derivative of a pow()");
}

// LCOV_EXCL_STOP

llvm::Value *taylor_diff_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &f, const std::vector<std::uint32_t> &deps,
const std::vector<llvm::Value *> &arr, llvm::Value *par_ptr, std::uint32_t n_uvars,
std::uint32_t order, std::uint32_t idx, std::uint32_t batch_size)
{
assert(f.args().size() == 2u);

// LCOV_EXCL_START
if (!deps.empty()) {
throw std::invalid_argument(
fmt::format("An empty hidden dependency vector is expected in order to compute the Taylor "
"derivative of the exponentiation, but a vector of size {} was passed "
"instead",
deps.size()));
}
// LCOV_EXCL_STOP

return std::visit(
[&](const auto &v1, const auto &v2) {
Expand Down Expand Up @@ -898,7 +895,7 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &s, llvm::Type *fp_t, con
auto *ft = llvm::FunctionType::get(val_t, fargs, false);
// Create the function
f = llvm::Function::Create(ft, llvm::Function::InternalLinkage, fname, &md);
assert(f != nullptr);
assert(f != nullptr); // LCOV_EXCL_LINE

// Fetch the necessary function arguments.
auto ord = f->args().begin();
Expand Down Expand Up @@ -969,6 +966,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &s, llvm::Type *fp_t, con
return f;
}

// LCOV_EXCL_START

// All the other cases.
template <typename U1, typename U2, std::enable_if_t<!std::conjunction_v<is_num_param<U1>, is_num_param<U2>>, int> = 0>
llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const pow_impl &, const U1 &, const U2 &,
Expand All @@ -978,6 +977,8 @@ llvm::Function *taylor_c_diff_func_pow_impl(llvm_state &, llvm::Type *, const po
"of a pow() in compact mode");
}

// LCOV_EXCL_STOP

llvm::Function *taylor_c_diff_func_pow(llvm_state &s, llvm::Type *fp_t, const pow_impl &fn, std::uint32_t n_uvars,
std::uint32_t batch_size)
{
Expand Down
78 changes: 78 additions & 0 deletions src/taylor_01.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstdint>
#include <deque>
#include <exception>
#include <iterator>
#include <limits>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -46,6 +48,7 @@

#include <heyoka/config.hpp>
#include <heyoka/detail/cm_utils.hpp>
#include <heyoka/detail/func_cache.hpp>
#include <heyoka/detail/llvm_func_create.hpp>
#include <heyoka/detail/llvm_helpers.hpp>
#include <heyoka/detail/logging_impl.hpp>
Expand All @@ -54,7 +57,11 @@
#include <heyoka/detail/type_traits.hpp>
#include <heyoka/detail/visibility.hpp>
#include <heyoka/expression.hpp>
#include <heyoka/func.hpp>
#include <heyoka/llvm_state.hpp>
#include <heyoka/math/exp.hpp>
#include <heyoka/math/log.hpp>
#include <heyoka/math/pow.hpp>
#include <heyoka/math/prod.hpp>
#include <heyoka/math/sum.hpp>
#include <heyoka/number.hpp>
Expand Down Expand Up @@ -769,6 +776,74 @@ void taylor_decompose_replace_numbers(taylor_dc_t &dc, std::vector<expression>::
}
}

// NOLINTNEXTLINE(misc-no-recursion)
expression pow_to_explog(funcptr_map<expression> &func_map, const expression &ex)
{
return std::visit(
// NOLINTNEXTLINE(misc-no-recursion)
[&]<typename T>(const T &v) {
if constexpr (std::same_as<T, func>) {
const auto *f_id = v.get_ptr();

// Check if we already performed the transformation on ex.
if (const auto it = func_map.find(f_id); it != func_map.end()) {
return it->second;
}

// Perform the transformation on the function arguments.
std::vector<expression> new_args;
new_args.reserve(v.args().size());
for (const auto &orig_arg : v.args()) {
new_args.push_back(pow_to_explog(func_map, orig_arg));
}

// Prepare the return value.
std::optional<expression> retval;

if (v.template extract<detail::pow_impl>() != nullptr
&& !std::holds_alternative<number>(new_args[1].value())) {
// The function is a pow() and the exponent is not a number: transform x**y -> exp(y*log(x)).
//
// NOTE: do not call directly log(new_args[0]) in order to avoid constant folding when the base
// is a number. For instance, if we have pow(2_dbl, par[0]), then we would end up computing
// log(2) in double precision. This would result in an inaccurate result if the fp type
// or precision in use during integration is higher than double.
// NOTE: because the exponent is not a number, no other constant folding should take
// place here.
retval.emplace(exp(new_args[1] * expression{func{detail::log_impl(new_args[0])}}));
} else {
// Create a copy of v with the new arguments.
retval.emplace(v.copy(std::move(new_args)));
}

// Put the return value into the cache.
[[maybe_unused]] const auto [_, flag] = func_map.emplace(f_id, *retval);
// NOTE: an expression cannot contain itself.
assert(flag); // LCOV_EXCL_LINE

return std::move(*retval);
} else {
return ex;
}
},
ex.value());
}

// Helper to transform x**y -> exp(y*log(x)), if y is not a number.
std::vector<expression> pow_to_explog(const std::vector<expression> &v_ex)
{
funcptr_map<expression> func_map;

std::vector<expression> retval;
retval.reserve(v_ex.size());

for (const auto &e : v_ex) {
retval.push_back(pow_to_explog(func_map, e));
}

return retval;
}

} // namespace

} // namespace detail
Expand Down Expand Up @@ -798,6 +873,9 @@ taylor_decompose_sys(const std::vector<std::pair<expression, expression>> &sys_,
std::ranges::transform(sys_, std::back_inserter(all_ex), &std::pair<expression, expression>::second);
all_ex.insert(all_ex.end(), sv_funcs_.begin(), sv_funcs_.end());

// Transform x**y -> exp(y*log(x)), if y is not a number.
all_ex = detail::pow_to_explog(all_ex);

// Transform sums into subs.
all_ex = detail::sum_to_sub(all_ex);

Expand Down
16 changes: 16 additions & 0 deletions test/llvm_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3397,6 +3397,20 @@ TEST_CASE("is_finite scalar mp")

#endif

// NOTE: is_natural() appears not to be working on ppc, giving the following error:
//
// LLVM ERROR: Cannot select: 0x63b95812620: v1i128 = bitcast 0x63b957f9700
// 0x63b957f9700: f128,ch = load<(load (s128) from %ir.2 + 16, align 1)> 0x63b95acca30, 0x63b957f9690, undef:i64
// 0x63b957f9690: i64 = add nuw 0x63b957f9380, Constant:i64<16>
// 0x63b957f9380: i64,ch = CopyFromReg 0x63b95acca30, Register:i64 %1
// 0x63b957f9310: i64 = Register %1
// 0x63b957f9620: i64 = Constant<16>
// 0x63b957f9460: i64 = undef
// In function: hey_is_natural
//
// This seems like an instruction selection problem specific to the ppc backend.
#if !defined(HEYOKA_ARCH_PPC)

TEST_CASE("is_natural scalar")
{
using detail::llvm_is_natural;
Expand Down Expand Up @@ -3578,3 +3592,5 @@ TEST_CASE("is_natural scalar mp")
}

#endif

#endif
Loading

0 comments on commit 7c01a83

Please sign in to comment.