Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring conjugated_accessor and scaled_accessor in line with the C++ Working Draft #260

Merged
merged 16 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/experimental/__p1673_bits/blas1_vector_abs_sum.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ Scalar vector_abs_sum(
using value_type = typename decltype(v)::value_type;
using sum_type =
decltype(init +
impl::abs_if_needed(impl::real_part(std::declval<value_type>())) +
impl::abs_if_needed(impl::imag_part(std::declval<value_type>())));
impl::abs_if_needed(impl::real_if_needed(std::declval<value_type>())) +
impl::abs_if_needed(impl::imag_if_needed(std::declval<value_type>())));
static_assert(std::is_convertible_v<sum_type, Scalar>);
// TODO Implement the Remarks in para 4.

Expand All @@ -100,8 +100,8 @@ Scalar vector_abs_sum(
}
else {
for (SizeType i = 0; i < numElt; ++i) {
init += impl::abs_if_needed(impl::real_part(v(i)));
init += impl::abs_if_needed(impl::imag_part(v(i)));
init += impl::abs_if_needed(impl::real_if_needed(v(i)));
init += impl::abs_if_needed(impl::imag_if_needed(v(i)));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,15 +534,15 @@ void hermitian_matrix_rank_1_update(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = j; i < A.extent(0); ++i) {
A(i,j) += alpha * x(i) * impl::conj_if_needed(x(j));
}
}
}
else {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = 0; i <= j; ++i) {
A(i,j) += alpha * x(i) * impl::conj_if_needed(x(j));
}
Expand Down Expand Up @@ -643,15 +643,15 @@ void hermitian_matrix_rank_1_update(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = j; i < A.extent(0); ++i) {
A(i,j) += x(i) * impl::conj_if_needed(x(j));
}
}
}
else {
for (size_type j = 0; j < A.extent(1); ++j) {
A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = 0; i <= j; ++i) {
A(i,j) += x(i) * impl::conj_if_needed(x(j));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ void hermitian_matrix_rank_2_update(
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? A.extent(0) : j+1;

A(j,j) = impl::real_part(A(j,j));
A(j,j) = impl::real_if_needed(A(j,j));
for (size_type i = i_lower; i < i_upper; ++i) {
A(i,j) += x(i) * impl::conj_if_needed(y(j)) + y(i) * impl::conj_if_needed(x(j));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ void hermitian_matrix_vector_product(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
y(j) += impl::real_part(A(j,j)) * x(j);
y(j) += impl::real_if_needed(A(j,j)) * x(j);
for (size_type i = j + size_type(1); i < A.extent(0); ++i) {
const auto A_ij = A(i,j);
y(i) += A_ij * x(j);
Expand All @@ -782,7 +782,7 @@ void hermitian_matrix_vector_product(
y(i) += A_ij * x(j);
y(j) += impl::conj_if_needed(A_ij) * x(i);
}
y(j) += impl::real_part(A(j,j)) * x(j);
y(j) += impl::real_if_needed(A(j,j)) * x(j);
}
}
}
Expand Down Expand Up @@ -897,7 +897,7 @@ void hermitian_matrix_vector_product(

if constexpr (std::is_same_v<Triangle, lower_triangle_t>) {
for (size_type j = 0; j < A.extent(1); ++j) {
z(j) += impl::real_part(A(j,j)) * x(j);
z(j) += impl::real_if_needed(A(j,j)) * x(j);
for (size_type i = j + size_type(1); i < A.extent(0); ++i) {
const auto A_ij = A(i,j);
z(i) += A_ij * x(j);
Expand All @@ -912,7 +912,7 @@ void hermitian_matrix_vector_product(
z(i) += A_ij * x(j);
z(j) += impl::conj_if_needed(A_ij) * x(i);
}
z(j) += impl::real_part(A(j,j)) * x(j);
z(j) += impl::real_if_needed(A(j,j)) * x(j);
}
}
}
Expand Down
20 changes: 10 additions & 10 deletions include/experimental/__p1673_bits/blas3_matrix_product.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ constexpr bool valid_input_blas_accessor()

using def_acc_type = default_accessor<elt_type>;
using conj_def_acc_type = conjugated_accessor<def_acc_type>;
using scal_def_acc_type = accessor_scaled<val_type, def_acc_type>;
using scal_conj_acc_type = accessor_scaled<val_type, conj_def_acc_type>;
using scal_def_acc_type = scaled_accessor<val_type, def_acc_type>;
using scal_conj_acc_type = scaled_accessor<val_type, conj_def_acc_type>;
using conj_scal_acc_type = conjugated_accessor<scal_def_acc_type>;

// The two matrices' accessor types need not be the same.
Expand Down Expand Up @@ -288,12 +288,12 @@ matrix_product_dispatch_to_blas()
}

template<class Accessor, class ValueType>
static constexpr bool is_compatible_accessor_scaled_v = false;
static constexpr bool is_compatible_scaled_accessor_v = false;

template<class ScalingFactor, class NestedAccessor, class ValueType>
static constexpr bool is_compatible_accessor_scaled_v<
accessor_scaled<ScalingFactor, NestedAccessor>, ValueType> =
std::is_same_v<typename accessor_scaled<ScalingFactor, NestedAccessor>::value_type, ValueType>;
static constexpr bool is_compatible_scaled_accessor_v<
scaled_accessor<ScalingFactor, NestedAccessor>, ValueType> =
std::is_same_v<typename scaled_accessor<ScalingFactor, NestedAccessor>::value_type, ValueType>;

// Primary variable template: by default an accessor type is not
// treated as a conjugated_accessor, so the trait is false.
// NOTE(review): the specialization(s) that set this to true are
// presumably declared right after this one -- not visible here; confirm.
template<class Accessor>
static constexpr bool is_conjugated_accessor_v = false;
Expand All @@ -309,12 +309,12 @@ extractScalingFactor(in_matrix_t A,
using acc_t = typename in_matrix_t::accessor_type;
using val_t = typename in_matrix_t::value_type;

if constexpr (is_compatible_accessor_scaled_v<acc_t, val_t>) {
if constexpr (is_compatible_scaled_accessor_v<acc_t, val_t>) {
return A.accessor.scale_factor();
} else if constexpr (is_conjugated_accessor_v<acc_t>) {
// conjugated(scaled(alpha, A)) means that both alpha and A are conjugated.
using nested_acc_t = decltype(A.accessor().nested_accessor());
if constexpr (is_compatible_accessor_scaled_v<nested_acc_t>) {
if constexpr (is_compatible_scaled_accessor_v<nested_acc_t>) {
return impl::conj_if_needed(extractScalingFactor(A.accessor.nested_accessor()));
} else {
return defaultValue;
Expand Down Expand Up @@ -1730,7 +1730,7 @@ void hermitian_matrix_product(
for (size_type k = 0; k < i; ++k){
C(i,j) += A(i,k) * B(k,j);
}
C(i,j) += impl::real_part(A(i,i)) * B(i,j);
C(i,j) += impl::real_if_needed(A(i,i)) * B(i,j);
for (size_type k = i+1; k < A.extent(0); ++k){
C(i,j) += impl::conj_if_needed(A(k,i)) * B(k,j);
}
Expand All @@ -1744,7 +1744,7 @@ void hermitian_matrix_product(
for (size_type k = 0; k < i; ++k) {
C(i,j) += impl::conj_if_needed(A(k,i)) * B(k,j);
}
C(i,j) += impl::real_part(A(i,i)) * B(i,j);
C(i,j) += impl::real_if_needed(A(i,i)) * B(i,j);
for (size_type k = i+1; k < A.extent(1); ++k) {
C(i,j) += A(i,k) * B(k,j);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void hermitian_matrix_rank_2k_update(
for (size_type j = 0; j < C.extent(1); ++j) {
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? C.extent(0) : j+1;
C(j,j) = impl::real_part(C(j,j));
C(j,j) = impl::real_if_needed(C(j,j));
for (size_type i = i_lower; i < i_upper; ++i) {
for (size_type k = 0; k < A.extent(1); ++k) {
C(i,j) += A(i,k) * impl::conj_if_needed(B(j,k)) + B(i,k) * impl::conj_if_needed(A(j,k));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ void hermitian_matrix_rank_k_update(
for (size_type j = 0; j < C.extent(1); ++j) {
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? C.extent(0) : j+1;
C(j, j) = impl::real_part(C(j, j));
C(j, j) = impl::real_if_needed(C(j, j));
for (size_type i = i_lower; i < i_upper; ++i) {
for (size_type k = 0; k < A.extent(1); ++k) {
C(i, j) += alpha * A(i, k) * impl::conj_if_needed(A(j, k));
Expand Down Expand Up @@ -456,7 +456,7 @@ void hermitian_matrix_rank_k_update(
using size_type = std::common_type_t<SizeType_A, SizeType_C>;

for (size_type j = 0; j < C.extent(1); ++j) {
C(j, j) = impl::real_part(C(j, j));
C(j, j) = impl::real_if_needed(C(j, j));
const size_type i_lower = lower_tri ? j : size_type(0);
const size_type i_upper = lower_tri ? C.extent(0) : j+1;
for (size_type i = i_lower; i < i_upper; ++i) {
Expand Down
3 changes: 3 additions & 0 deletions include/experimental/__p1673_bits/conjugated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ class conjugated_accessor {
class OtherNestedAccessor,
/* requires */ (std::is_convertible_v<NestedAccessor, const OtherNestedAccessor&>)
)
#if defined(__cpp_conditional_explicit)
explicit(!std::is_convertible_v<OtherNestedAccessor, NestedAccessor>)
#endif
constexpr conjugated_accessor(const conjugated_accessor<OtherNestedAccessor>& other)
: nested_accessor_(other.nested_accessor())
{}
Expand Down
96 changes: 96 additions & 0 deletions include/experimental/__p1673_bits/imag_if_needed.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
//@HEADER
// ************************************************************************
//
// Kokkos v. 2.0
// Copyright (2019) Sandia Corporation
//
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Christian R. Trott ([email protected])
//
// ************************************************************************
//@HEADER
*/

#ifndef LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_IMAG_IF_NEEDED_HPP_
#define LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_IMAG_IF_NEEDED_HPP_

#include <complex>
#include <type_traits>

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
namespace MDSPAN_IMPL_PROPOSED_NAMESPACE {
inline namespace __p1673_version_0 {
namespace linalg {
namespace impl{

// Detection trait: has_imag<T>::value is true exactly when the
// unqualified call imag(std::declval<T>()) is well-formed.
//
// Primary template: the call is not well-formed.
template<class T, class Enable = void>
struct has_imag : std::false_type {};

// Specialization chosen when unqualified imag can be found via
// overload resolution (in practice, by argument-dependent lookup).
// In that case we assume that imag(t) returns the imaginary part of t.
template<class T>
struct has_imag<T, std::void_t<decltype(imag(std::declval<T>()))>>
  : std::true_type {};

// Fallback overload, selected by the std::false_type tag.
//
// If imag(t) can't be ADL-found, then assume that T represents a
// noncomplex number type; its imaginary part is zero.
//
// The first parameter only drives deduction of T, so it is left
// unnamed (avoids unused-parameter warnings). The function is
// constexpr so that it may be used in constant expressions
// whenever T{} is itself a constant expression.
//
// Returns: a value-initialized T.
template<class T>
constexpr T imag_if_needed_impl(const T& /* t */, std::false_type)
{
  return T{};
}

// Overload selected by the std::true_type tag, i.e., when an
// unqualified call to imag is well-formed for the argument.
//
// For arithmetic T, the result is a value-initialized T rather than
// the result of imag: overloads of imag for integers have a return
// type of double, and we want to preserve the input type T.
// For every other T, the result is imag(t), found by unqualified lookup.
//
// The function is constexpr so that it may be used in constant
// expressions whenever the selected branch is a constant expression.
template<class T>
constexpr auto imag_if_needed_impl(const T& t, std::true_type)
{
  if constexpr (std::is_arithmetic_v<T>) {
    return T{};
  } else {
    return imag(t);
  }
}

// Inline static variables require C++17.
constexpr inline auto imag_if_needed = [](const auto& t)
{
using T = std::remove_const_t<decltype(t)>;
return imag_if_needed_impl(t, has_imag<T>{});
};

} // end namespace impl
} // end namespace linalg
} // end inline namespace __p1673_version_0
} // end namespace MDSPAN_IMPL_PROPOSED_NAMESPACE
} // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE

#endif //LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_IMAG_IF_NEEDED_HPP_
71 changes: 2 additions & 69 deletions include/experimental/__p1673_bits/proxy_reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
#ifndef LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_PROXY_REFERENCE_HPP_
#define LINALG_INCLUDE_EXPERIMENTAL___P1673_BITS_PROXY_REFERENCE_HPP_

#include "conjugate_if_needed.hpp"
#if defined(__cpp_lib_atomic_ref) && defined(LINALG_ENABLE_ATOMIC_REF)
# include <atomic>
#endif
Expand All @@ -67,72 +66,6 @@ template<class U>
static constexpr bool is_atomic_ref_not_arithmetic_v<std::atomic_ref<U>> = ! std::is_arithmetic_v<U>;
#endif

// Detection trait: has_imag<T>::value is true exactly when the
// unqualified call imag(std::declval<T>()) is well-formed.
//
// Primary template: the call is not well-formed.
template<class T, class Enable = void>
struct has_imag : std::false_type {};

// Specialization chosen when unqualified imag can be found via
// overload resolution. In that case we assume that imag(t) returns
// the imaginary part of t.
template<class T>
struct has_imag<T, std::void_t<decltype(imag(std::declval<T>()))>>
  : std::true_type {};

// Fallback overload, selected by the std::false_type tag: the
// imaginary part is taken to be zero.
//
// The first parameter only drives deduction of T, so it is left
// unnamed (avoids unused-parameter warnings). The function is
// constexpr so that it may be used in constant expressions
// whenever T{} is itself a constant expression.
//
// Returns: a value-initialized T.
template<class T>
constexpr T imag_part_impl(const T& /* t */, std::false_type)
{
  return T{};
}

// Overload selected by the std::true_type tag.
//
// Non-arithmetic T: forwards to imag(t), found by unqualified lookup.
// Arithmetic T: returns a value-initialized T, so that the result
// has the same type as the input.
template<class T>
auto imag_part_impl(const T& t, std::true_type)
{
  if constexpr (! std::is_arithmetic_v<T>) {
    return imag(t);
  } else {
    return T{};
  }
}

template<class T>
auto imag_part(const T& t)
{
return imag_part_impl(t, has_imag<T>{});
}

// Detection trait: has_real<T>::value is true exactly when the
// unqualified call real(std::declval<T>()) is well-formed.
//
// Primary template: the call is not well-formed.
template<class T, class Enable = void>
struct has_real : std::false_type {};

// Specialization chosen when unqualified real can be found via
// overload resolution. In that case we assume that real(t) returns
// the real part of t.
template<class T>
struct has_real<T, std::void_t<decltype(real(std::declval<T>()))>>
  : std::true_type {};

// Fallback overload, selected by the std::false_type tag: the
// argument is taken to be its own real part and is returned
// unchanged (as a copy).
template<class T>
auto real_part_impl(const T& t, std::false_type) -> T
{
  return t;
}

// Overload selected by the std::true_type tag.
//
// Non-arithmetic T: forwards to real(t), found by unqualified lookup.
// Arithmetic T: returns t itself, so that the result has the same
// type as the input.
template<class T>
auto real_part_impl(const T& t, std::true_type)
{
  if constexpr (! std::is_arithmetic_v<T>) {
    return real(t);
  } else {
    return t;
  }
}

template<class T>
auto real_part(const T& t)
{
return real_part_impl(t, has_real<T>{});
}

// template<class R>
// R imag_part(const std::complex<R>& z)
// {
// return std::imag(z);
// }

// A "tag" for identifying the proxy reference types in this proposal.
// It's helpful for this tag to be a complete type, so that we can use
// it inside proxy_reference (proxy_reference isn't really complete
Expand Down Expand Up @@ -240,11 +173,11 @@ class proxy_reference : proxy_reference_base {
}

friend auto real(const derived_type& x) {
return real_part(value_type(static_cast<const this_type&>(x)));
return impl::real_if_needed(value_type(static_cast<const this_type&>(x)));
}

friend auto imag(const derived_type& x) {
return imag_part(value_type(static_cast<const this_type&>(x)));
return impl::imag_if_needed(value_type(static_cast<const this_type&>(x)));
}

friend auto conj(const derived_type& x) {
Expand Down
Loading