Skip to content

Commit

Permalink
Several fixes, provisional WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
bluescarni committed Nov 10, 2023
1 parent 9016ef8 commit ef8e031
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 10 deletions.
40 changes: 30 additions & 10 deletions src/expression_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,9 @@ auto diff_tensors_impl(const std::vector<expression> &v_ex, const std::vector<ex
// of the previous-order derivatives.
tmp_v_idx.first = 0;
tmp_v_idx.second.clear();
tmp_v_idx.second.emplace_back(0, cur_order);
if (cur_order != 0u) {
tmp_v_idx.second.emplace_back(0, cur_order);
}

const auto prev_begin = search_diff_map(tmp_v_idx);
assert(prev_begin != diff_map.end());
Expand Down Expand Up @@ -1477,8 +1479,12 @@ std::uint32_t dtens::get_order() const
// in the map (specifically, it is
// the last element in the indices
// vector of the last derivative).
assert(!(end() - 1)->first.second.empty());
return (end() - 1)->first.second.back().second;
const auto &sv = (end() - 1)->first.second;
if (sv.empty()) {
return 0;
} else {
return sv.back().second;
}
}

dtens::iterator dtens::find(const v_idx_t &vidx) const
Expand All @@ -1493,6 +1499,12 @@ dtens::iterator dtens::find(const v_idx_t &vidx) const
return end();
}

// The size of vidx must be consistent with the number
// of diff args.
if (vidx.size() - 1u != get_nvars()) {
return end();
}

// Turn vidx into sparse format.
detail::dtens_sv_idx_t s_vidx{vidx[0], {}};
for (decltype(vidx.size()) i = 1; i < vidx.size(); ++i) {
Expand Down Expand Up @@ -1538,7 +1550,10 @@ dtens::subrange dtens::get_derivatives(std::uint32_t order) const

// Create the indices vector corresponding to the first derivative
// of component 0 for the given order in the map.
detail::dtens_sv_idx_t s_vidx{0, {{0, order}}};
detail::dtens_sv_idx_t s_vidx{0, {}};
if (order != 0u) {
s_vidx.second.emplace_back(0, order);
}

// Locate the corresponding derivative in the map.
// NOTE: this could be end() for invalid order.
Expand All @@ -1560,8 +1575,10 @@ dtens::subrange dtens::get_derivatives(std::uint32_t order) const
// map is empty, and we handled this corner case earlier.
assert(get_nouts() > 0u);
s_vidx.first = get_nouts() - 1u;
assert(get_nvars() > 0u);
s_vidx.second[0].first = get_nvars() - 1u;
if (order != 0u) {
assert(get_nvars() > 0u);
s_vidx.second[0].first = get_nvars() - 1u;
}

// NOTE: this could be end() for invalid order.
auto e = p_impl->m_map.find(s_vidx);
Expand Down Expand Up @@ -1596,7 +1613,10 @@ dtens::subrange dtens::get_derivatives(std::uint32_t component, std::uint32_t or

// Create the indices vector corresponding to the first derivative
// for the given order and component in the map.
detail::dtens_sv_idx_t s_vidx{component, {{0, order}}};
detail::dtens_sv_idx_t s_vidx{component, {}};
if (order != 0u) {
s_vidx.second.emplace_back(0, order);
}

// Locate the corresponding derivative in the map.
// NOTE: this could be end() for invalid component/order.
Expand All @@ -1615,7 +1635,9 @@ dtens::subrange dtens::get_derivatives(std::uint32_t component, std::uint32_t or
// Modify vidx so that it now refers to the last derivative
// for the given order and component in the map.
assert(get_nvars() > 0u);
s_vidx.second[0].first = get_nvars() - 1u;
if (order != 0u) {
s_vidx.second[0].first = get_nvars() - 1u;
}

// NOTE: this could be end() for invalid component/order.
auto e = p_impl->m_map.find(s_vidx);
Expand Down Expand Up @@ -1692,8 +1714,6 @@ std::uint32_t dtens::get_nvars() const

if (p_impl->m_map.empty()) {
assert(ret == 0u);
} else {
assert(!begin()->first.second.empty());
}

#endif
Expand Down
4 changes: 4 additions & 0 deletions test/expression_diff_tensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ TEST_CASE("dtens basics")
REQUIRE(dt3.get_args() == dt2.get_args());
}

#if 0

TEST_CASE("fixed centres check")
{
std::uniform_real_distribution<double> rdist(-10., 10.);
Expand Down Expand Up @@ -577,6 +579,8 @@ TEST_CASE("speelpenning check")
}
}

#endif

TEST_CASE("speelpenning complexity")
{
fmt::print("Speelpenning's example\n");
Expand Down

0 comments on commit ef8e031

Please sign in to comment.