Skip to content

Commit

Permalink
Bring repo in line with Standard
Browse files Browse the repository at this point in the history
* Rename accessor_scaled to scaled_accessor, and
  accessor_conjugate to conjugated_accessor

* Remove proxy references from both accessor types

* Add the missing conditional `explicit` specifier to both accessors'
  converting constructors

* Otherwise bring conjugated_accessor and scaled_accessor
  in line with the Standard

* Fix namespaces where needed

* Fix implementation of conj-if-needed, and add implementations (and
  tests) of real-if-needed and imag-if-needed

* Add tests that mix the result of `scaled` with an mdspan whose accessor
  uses a proxy reference as its reference type.  This increases
  confidence that the proxy-reference-free design works.
  • Loading branch information
mhoemmen committed Sep 27, 2024
1 parent 22833a9 commit d18fd5c
Show file tree
Hide file tree
Showing 23 changed files with 520 additions and 701 deletions.
8 changes: 4 additions & 4 deletions include/experimental/__p1673_bits/blas1_vector_abs_sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ Scalar vector_abs_sum(
using value_type = typename decltype(v)::value_type;
using sum_type =
decltype(init +
impl::abs_if_needed(impl::real_part(std::declval<value_type>())) +
impl::abs_if_needed(impl::imag_part(std::declval<value_type>())));
impl::abs_if_needed(impl::real_if_needed(std::declval<value_type>())) +
impl::abs_if_needed(impl::imag_if_needed(std::declval<value_type>())));
static_assert(std::is_convertible_v<sum_type, Scalar>);
// TODO Implement the Remarks in para 4.

Expand All @@ -100,8 +100,8 @@ Scalar vector_abs_sum(
}
else {
for (SizeType i = 0; i < numElt; ++i) {
init += impl::abs_if_needed(impl::real_part(v(i)));
init += impl::abs_if_needed(impl::imag_part(v(i)));
init += impl::abs_if_needed(impl::real_if_needed(v(i)));
init += impl::abs_if_needed(impl::imag_if_needed(v(i)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,15 +534,15 @@ void hermitian_matrix_rank_1_update(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = j; i < A.extent(0); ++i) {
A(i,j) += alpha * x(i) * impl::conj_if_needed(x(j));
}
}
}
else {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = 0; i <= j; ++i) {
A(i,j) += alpha * x(i) * impl::conj_if_needed(x(j));
}
Expand Down Expand Up @@ -643,15 +643,15 @@ void hermitian_matrix_rank_1_update(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = j; i < A.extent(0); ++i) {
A(i,j) += x(i) * impl::conj_if_needed(x(j));
}
}
}
else {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = 0; i <= j; ++i) {
A(i,j) += x(i) * impl::conj_if_needed(x(j));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ void hermitian_matrix_rank_2_update(
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? A.extent(0) : j+1;

A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = i_lower; i < i_upper; ++i) {
A(i,j) += x(i) * impl::conj_if_needed(y(j)) + y(i) * impl::conj_if_needed(x(j));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ void hermitian_matrix_vector_product(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
y(j) += impl::real_part(A(j,j)) * x(j);
y(j) += impl::real_if_needed(A(j,j)) * x(j);
for (size_type i = j + size_type(1); i < A.extent(0); ++i) {
const auto A_ij = A(i,j);
y(i) += A_ij * x(j);
Expand All @@ -782,7 +782,7 @@ void hermitian_matrix_vector_product(
y(i) += A_ij * x(j);
y(j) += impl::conj_if_needed(A_ij) * x(i);
}
y(j) += impl::real_part(A(j,j)) * x(j);
y(j) += impl::real_if_needed(A(j,j)) * x(j);
}
}
}
Expand Down Expand Up @@ -897,7 +897,7 @@ void hermitian_matrix_vector_product(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
z(j) += impl::real_part(A(j,j)) * x(j);
z(j) += impl::real_if_needed(A(j,j)) * x(j);
for (size_type i = j + size_type(1); i < A.extent(0); ++i) {
const auto A_ij = A(i,j);
z(i) += A_ij * x(j);
Expand All @@ -912,7 +912,7 @@ void hermitian_matrix_vector_product(
z(i) += A_ij * x(j);
z(j) += impl::conj_if_needed(A_ij) * x(i);
}
z(j) += impl::real_part(A(j,j)) * x(j);
z(j) += impl::real_if_needed(A(j,j)) * x(j);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions include/experimental/__p1673_bits/blas3_matrix_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ constexpr bool valid_input_blas_accessor()

using def_acc_type = default_accessor<elt_type>;
using conj_def_acc_type = conjugated_accessor<def_acc_type>;
using scal_def_acc_type = accessor_scaled<val_type, def_acc_type>;
using scal_conj_acc_type = accessor_scaled<val_type, conj_def_acc_type>;
using scal_def_acc_type = scaled_accessor<val_type, def_acc_type>;
using scal_conj_acc_type = scaled_accessor<val_type, conj_def_acc_type>;
using conj_scal_acc_type = conjugated_accessor<scal_def_acc_type>;

// The two matrices' accessor types need not be the same.
Expand Down Expand Up @@ -1730,7 +1730,7 @@ void hermitian_matrix_product(
for (size_type k = 0; k < i; ++k){
C(i,j) += A(i,k) * B(k,j);
}
C(i,j) += impl::real_part(A(i,i)) * B(i,j);
C(i,j) += impl::real_if_needed(A(i,i)) * B(i,j);
for (size_type k = i+1; k < A.extent(0); ++k){
C(i,j) += impl::conj_if_needed(A(k,i)) * B(k,j);
}
Expand All @@ -1744,7 +1744,7 @@ void hermitian_matrix_product(
for (size_type k = 0; k < i; ++k) {
C(i,j) += impl::conj_if_needed(A(k,i)) * B(k,j);
}
C(i,j) += impl::real_part(A(i,i)) * B(i,j);
C(i,j) += impl::real_if_needed(A(i,i)) * B(i,j);
for (size_type k = i+1; k < A.extent(1); ++k) {
C(i,j) += A(i,k) * B(k,j);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void hermitian_matrix_rank_2k_update(
for (size_type j = 0; j < C.extent(1); ++j) {
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? C.extent(0) : j+1;
C(j,j) = impl::real_part(C(j,j));
C(j,j) = impl::real_if_needed(C(j,j));
for (size_type i = i_lower; i < i_upper; ++i) {
for (size_type k = 0; k < A.extent(1); ++k) {
C(i,j) += A(i,k) * impl::conj_if_needed(B(j,k)) + B(i,k) * impl::conj_if_needed(A(j,k));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ void hermitian_matrix_rank_k_update(
for (size_type j = 0; j < C.extent(1); ++j) {
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? C.extent(0) : j+1;
C(j, j) = impl::real_part(C(j, j));
C(j, j) = impl::real_if_needed(C(j, j));
for (size_type i = i_lower; i < i_upper; ++i) {
for (size_type k = 0; k < A.extent(1); ++k) {
C(i, j) += alpha * A(i, k) * impl::conj_if_needed(A(j, k));
Expand Down Expand Up @@ -456,7 +456,7 @@ void hermitian_matrix_rank_k_update(
using size_type = std::common_type_t<SizeType_A, SizeType_C>;

for (size_type j = 0; j < C.extent(1); ++j) {
C(j, j) = impl::real_part(C(j, j));
C(j, j) = impl::real_if_needed(C(j, j));
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? C.extent(0) : j+1;
for (size_type i = i_lower; i < i_upper; ++i) {
Expand Down
3 changes: 3 additions & 0 deletions include/experimental/__p1673_bits/conjugated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class conjugated_accessor {
class OtherNestedAccessor,
/* requires */ (std::is_convertible_v<NestedAccessor, const OtherNestedAccessor&>)
)
#if defined(__cpp_conditional_explicit)
explicit(!std::is_convertible_v<OtherNestedAccessor, NestedAccessor>)
#endif
constexpr conjugated_accessor(const conjugated_accessor<OtherNestedAccessor>& other)
: nested_accessor_(other.nested_accessor())
{}
Expand Down
8 changes: 4 additions & 4 deletions include/experimental/__p1673_bits/imag_if_needed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
#include <complex>
#include <type_traits>

namespace std {
namespace experimental {
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
inline namespace __p1673_version_0 {
namespace linalg {
namespace impl{
Expand Down Expand Up @@ -90,7 +90,7 @@ constexpr inline auto imag_if_needed = [](const auto& t)
} // end namespace impl
} // end namespace linalg
} // end inline namespace __p1673_version_0
} // end namespace experimental
} // end namespace std
} // end namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
} // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE

#endif //LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_IMAG_IF_NEEDED_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
#ifndef LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_PROXY_REFERENCE_HPP_
#define LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_PROXY_REFERENCE_HPP_

#include "experimental/__p1673_bits/conj_if_needed.hpp"
#if defined(__cpp_lib_atomic_ref) && defined(LINALG_ENABLE_ATOMIC_REF)
# include <atomic>
#endif
Expand All @@ -67,56 +66,6 @@ template<class U>
static constexpr bool is_atomic_ref_not_arithmetic_v<std::atomic_ref<U>> = ! std::is_arithmetic_v<U>;
#endif

template<class T>
T imag_part_impl(const T& t, std::false_type)
{
return T{};
}

template<class T>
auto imag_part_impl(const T& t, std::true_type)
{
if constexpr (std::is_arithmetic_v<T>) {
return T{};
} else {
return imag(t);
}
}

template<class T>
auto imag_part(const T& t)
{
return imag_part_impl(t, has_imag<T>{});
}

template<class T>
T real_part_impl(const T& t, std::false_type)
{
return t;
}

template<class T>
auto real_part_impl(const T& t, std::true_type)
{
if constexpr (std::is_arithmetic_v<T>) {
return t;
} else {
return real(t);
}
}

template<class T>
auto real_part(const T& t)
{
return real_part_impl(t, has_real<T>{});
}

// template<class R>
// R imag_part(const std::complex<R>& z)
// {
// return std::imag(z);
// }

// A "tag" for identifying the proxy reference types in this proposal.
// It's helpful for this tag to be a complete type, so that we can use
// it inside proxy_reference (proxy_reference isn't really complete
Expand Down Expand Up @@ -224,11 +173,11 @@ class proxy_reference : proxy_reference_base {
}

friend auto real(const derived_type& x) {
return real_part(value_type(static_cast<const this_type&>(x)));
return impl::real_if_needed(value_type(static_cast<const this_type&>(x)));
}

friend auto imag(const derived_type& x) {
return imag_part(value_type(static_cast<const this_type&>(x)));
return impl::imag_if_needed(value_type(static_cast<const this_type&>(x)));
}

friend auto conj(const derived_type& x) {
Expand Down
8 changes: 4 additions & 4 deletions include/experimental/__p1673_bits/real_if_needed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
#include <complex>
#include <type_traits>

namespace std {
namespace experimental {
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
inline namespace __p1673_version_0 {
namespace linalg {
namespace impl{
Expand Down Expand Up @@ -90,7 +90,7 @@ constexpr inline auto real_if_needed = [](const auto& t)
} // end namespace impl
} // end namespace linalg
} // end inline namespace __p1673_version_0
} // end namespace experimental
} // end namespace std
} // end namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
} // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE

#endif //LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_REAL_IF_NEEDED_HPP_
51 changes: 28 additions & 23 deletions include/experimental/__p1673_bits/scaled.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,62 +50,67 @@ namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
inline namespace __p1673_version_0 {
namespace linalg {

template<class ScalingFactor, class Accessor>
template<class ScalingFactor, class NestedAccessor>
class scaled_accessor {
public:
using element_type = decltype(std::declval<ScalingFactor>() * std::declval<typename Accessor::element_type>());
using reference = element_type;
using data_handle_type = typename Accessor::data_handle_type;
using element_type =
std::add_const_t<decltype(std::declval<ScalingFactor>() * std::declval<typename NestedAccessor::element_type>())>;
using reference = std::remove_const_t<element_type>;
using data_handle_type = typename NestedAccessor::data_handle_type;
using offset_policy =
scaled_accessor<ScalingFactor, typename Accessor::offset_policy>;
scaled_accessor<ScalingFactor, typename NestedAccessor::offset_policy>;

scaled_accessor(const ScalingFactor& scaling_factor, const Accessor& accessor) :
scaling_factor_(scaling_factor),
accessor_(accessor)
{}
constexpr scaled_accessor() = default;

MDSPAN_TEMPLATE_REQUIRES(
class OtherScalingFactor,
class OtherNestedAccessor,
/* requires */ (
std::is_constructible_v<Accessor, const OtherNestedAccessor&> &&
std::is_constructible_v<NestedAccessor, const OtherNestedAccessor&> &&
std::is_constructible_v<ScalingFactor, OtherScalingFactor>
)
)
scaled_accessor(const scaled_accessor<OtherScalingFactor, OtherNestedAccessor>& other) :
#if defined(__cpp_conditional_explicit)
explicit(!std::is_convertible_v<OtherNestedAccessor, NestedAccessor>)
#endif
constexpr scaled_accessor(const scaled_accessor<OtherScalingFactor, OtherNestedAccessor>& other) :
scaling_factor_(other.scaling_factor()),
accessor_(other.nested_accessor())
nested_accessor_(other.nested_accessor())
{}

constexpr scaled_accessor(const ScalingFactor& s, const NestedAccessor& a) :
scaling_factor_(s),
nested_accessor_(a)
{}

reference access(data_handle_type p, ::std::size_t i) const
noexcept(noexcept(scaling_factor_* typename Accessor::element_type(accessor_.access(p, i)))) {
return scaling_factor_ * typename Accessor::element_type(accessor_.access(p, i));
constexpr reference access(data_handle_type p, ::std::size_t i) const {
return scaling_factor_ * typename NestedAccessor::element_type(nested_accessor_.access(p, i));
}

typename offset_policy::data_handle_type
offset(data_handle_type p, ::std::size_t i) const noexcept {
return accessor_.offset(p, i);
constexpr offset(data_handle_type p, ::std::size_t i) const {
return nested_accessor_.offset(p, i);
}

Accessor nested_accessor() const {
return accessor_;
constexpr NestedAccessor nested_accessor() const noexcept {
return nested_accessor_;
}

ScalingFactor scaling_factor() const {
constexpr ScalingFactor scaling_factor() const noexcept {
return scaling_factor_;
}

private:
ScalingFactor scaling_factor_;
Accessor accessor_;
NestedAccessor nested_accessor_;
};

namespace impl {

template<class ScalingFactor,
class Accessor>
class NestedAccessor>
using scaled_element_type =
std::add_const_t<typename scaled_accessor<ScalingFactor, Accessor>::element_type>;
std::add_const_t<typename scaled_accessor<ScalingFactor, NestedAccessor>::element_type>;

} // namespace impl

Expand Down
Loading

0 comments on commit d18fd5c

Please sign in to comment.