Skip to content

Commit

Permalink
InterpolatedShape::supportLoop (#19)
Browse files Browse the repository at this point in the history
* define support index type in BasisCombination, define a supportLoop in the interpolated shape

* add BasisProduct::supportLoop to avoid the creation of the std::integer_sequence< int, BASIS_TYPES...> basisSupportCounts{}

* remove commented out code

---------

Co-authored-by: Randolph Settgast <[email protected]>
  • Loading branch information
wrtobin and rrsettgast authored Nov 14, 2023
1 parent a24d221 commit 328fb97
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 34 deletions.
5 changes: 4 additions & 1 deletion src/common/NestedSequenceUtilities.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#pragma once

#include "common/ShivaMacros.hpp"
#include <type_traits>
#include <utility>
Expand Down Expand Up @@ -53,7 +55,8 @@ constexpr void forNestedSequence( FUNC && func )
}

template< int... ENDS, typename FUNC >
constexpr void forNestedSequence( std::integer_sequence<int, ENDS...>, FUNC && func )
constexpr void forNestedSequence( std::integer_sequence<int, ENDS...>,
FUNC && func )
{
forNestedSequence< ENDS... >( std::forward<FUNC>( func ) );
}
Expand Down
17 changes: 15 additions & 2 deletions src/functions/bases/BasisProduct.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "common/SequenceUtilities.hpp"
#include "common/NestedSequenceUtilities.hpp"
#include "common/MultiIndex.hpp"

namespace shiva
{
Expand All @@ -26,7 +27,19 @@ struct BasisProduct

/// Alias for the type that represents a coordinate
using CoordType = REAL_TYPE[numDims];


using IndexType = typename SequenceAlias< MultiIndexRangeI,
std::integer_sequence< int, BASIS_TYPES::numSupportPoints... > >::type;


template < typename FUNC >
SHIVA_STATIC_CONSTEXPR_HOSTDEVICE_FORCEINLINE void
supportLoop( FUNC && func )
{
forNestedSequence<BASIS_TYPES::numSupportPoints...>( std::forward< FUNC >( func ) );
}



/**
* @brief Calculates the value of the basis function at the specified parent
Expand Down
32 changes: 8 additions & 24 deletions src/geometry/mapping/LinearTransform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class LinearTransform
using CoordType = REAL_TYPE[numDims];

/// The type used to represent the index space of the cell
using IndexType = typename SequenceAlias< MultiIndexRangeI, decltype(InterpolatedShape::basisSupportCounts) >::type;
using SupportIndexType = typename InterpolatedShape::BasisCombinationType::IndexType;
// using IndexType = typename SequenceAlias< MultiIndexRangeI, decltype(InterpolatedShape::basisSupportCounts) >::type;

/**
* @brief Returns a boolean indicating whether the Jacobian is constant in the
Expand All @@ -80,7 +81,7 @@ class LinearTransform
* @brief Provides non-const access to member data through reference.
* @return a mutable reference to the member data.
*/
constexpr SHIVA_HOST_DEVICE SHIVA_FORCE_INLINE DataType & setData() { return m_vertexCoords; }
constexpr SHIVA_HOST_DEVICE SHIVA_FORCE_INLINE DataType & getData() { return m_vertexCoords; }


/**
Expand All @@ -98,23 +99,6 @@ class LinearTransform
}
}


/**
* @brief method to loop over the vertices of the cuboid
* @tparam FUNCTION_TYPE The type of the function to execute
* @param[in] func The function to execute
*/
template< typename FUNCTION_TYPE >
constexpr SHIVA_HOST_DEVICE SHIVA_FORCE_INLINE void forVertices( FUNCTION_TYPE && func ) const
{
IndexType index{0, 0, 0};

forRange( index, [this, func] ( auto const & i )
{
func( i, m_vertexCoords[linearIndex( i )] );
} );
}

private:
/// Data member that stores the vertex coordinates of the cuboid
DataType m_vertexCoords;
Expand All @@ -132,8 +116,7 @@ namespace utilities
*/
template< typename REAL_TYPE, typename INTERPOLATED_SHAPE >
SHIVA_STATIC_CONSTEXPR_HOSTDEVICE_FORCEINLINE void jacobian( LinearTransform< REAL_TYPE, INTERPOLATED_SHAPE > const &,//cell,
typename LinearTransform< REAL_TYPE, INTERPOLATED_SHAPE >::JacobianType::type & )//J
// )
typename LinearTransform< REAL_TYPE, INTERPOLATED_SHAPE >::JacobianType::type & )
{}

/**
Expand All @@ -152,17 +135,18 @@ jacobian( LinearTransform< REAL_TYPE, INTERPOLATED_SHAPE > const & cell,
typename LinearTransform< REAL_TYPE, INTERPOLATED_SHAPE >::JacobianType::type & J )
{
using Transform = std::remove_reference_t<decltype(cell)>;
using IndexType = typename Transform::IndexType;
using InterpolatedShape = typename Transform::InterpolatedShape;
using IndexType = typename InterpolatedShape::BasisCombinationType::IndexType;
constexpr int DIMS = Transform::numDims;

auto const & nodeCoords = cell.getData();
forNestedSequence( InterpolatedShape::basisSupportCounts,
InterpolatedShape::template supportLoop(
[&] ( auto const ... icNa ) constexpr
{
IndexType index{ { decltype(icNa)::value... } };
CArray1d< REAL_TYPE, DIMS > const dNadXi = INTERPOLATED_SHAPE::template gradient< decltype(icNa)::value... >( pointCoordsParent );
CArray1d< REAL_TYPE, DIMS > const dNadXi = InterpolatedShape::template gradient< decltype(icNa)::value... >( pointCoordsParent );
auto const & nodeCoord = nodeCoords[ flattenIndex( index ) ];
// dimensional loop from domain to codomain
forNestedSequence< DIMS, DIMS >(
[&] ( auto const ... icijk ) constexpr
{
Expand Down
6 changes: 3 additions & 3 deletions src/geometry/mapping/unitTests/testLinearTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ SHIVA_HOST_DEVICE auto makeLinearTransform( REAL_TYPE const (&X)[8][3] )
LagrangeBasis< double, 1, EqualSpacing >,
LagrangeBasis< double, 1, EqualSpacing > > > cell;

typename decltype(cell)::IndexType index;
typename decltype(cell)::SupportIndexType index;

auto & transformData = cell.setData();
auto & transformData = cell.getData();

forRange( index = {0, 0, 0}, [&transformData, &X] ( auto const & i )
{
Expand All @@ -121,7 +121,7 @@ void testConstructionAndSettersHelper()
pmpl::genericKernelWrapper( 8 * 3, data, [] SHIVA_DEVICE ( double * const kernelData )
{
auto const cell = makeLinearTransform( Xref );
typename decltype(cell)::IndexType index{0, 0, 0};
typename decltype(cell)::SupportIndexType index{0, 0, 0};

auto const & transformData = cell.getData();

Expand Down
14 changes: 10 additions & 4 deletions src/geometry/shapes/InterpolatedShape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,24 @@ class InterpolatedShape
using CoordType = typename StandardGeom::CoordType;

/// The type used to represent the product of basis functions
using BASIS_PRODUCT_TYPE = functions::BasisProduct< REAL_TYPE, BASIS_TYPE... >;
using BasisCombinationType = functions::BasisProduct< REAL_TYPE, BASIS_TYPE... >;

/// The number of dimensions on the InterpolatedShape
static inline constexpr int numDims = sizeof...(BASIS_TYPE);

/// The number of vertices on the InterpolatedShape
static inline constexpr int numVertices = StandardGeom::numVertices();

static inline constexpr std::integer_sequence< int, BASIS_TYPE::numSupportPoints... > basisSupportCounts{};
// static inline constexpr std::integer_sequence< int, BASIS_TYPE::numSupportPoints... > basisSupportCounts{};

static_assert( numDims == StandardGeom::numDims(), "numDims mismatch between cell and number of basis specified" );

template < typename FUNC >
SHIVA_STATIC_CONSTEXPR_HOSTDEVICE_FORCEINLINE void
supportLoop( FUNC && func )
{
BasisCombinationType::supportLoop( std::forward< FUNC >( func ) );
}

/**
* @copydoc functions::BasisProduct::value
Expand All @@ -63,7 +69,7 @@ class InterpolatedShape
value( CoordType const & parentCoord )
{
static_assert( sizeof...(BASIS_FUNCTION_INDICES) == numDims, "Wrong number of basis function indicies specified" );
return ( BASIS_PRODUCT_TYPE::template value< BASIS_FUNCTION_INDICES... >( parentCoord ) );
return ( BasisCombinationType::template value< BASIS_FUNCTION_INDICES... >( parentCoord ) );
}

/**
Expand All @@ -74,7 +80,7 @@ class InterpolatedShape
gradient( CoordType const & parentCoord )
{
static_assert( sizeof...(BASIS_FUNCTION_INDICES) == numDims, "Wrong number of basis function indicies specified" );
return ( BASIS_PRODUCT_TYPE::template gradient< BASIS_FUNCTION_INDICES... >( parentCoord ) );
return ( BasisCombinationType::template gradient< BASIS_FUNCTION_INDICES... >( parentCoord ) );
}
};

Expand Down

0 comments on commit 328fb97

Please sign in to comment.