Skip to content

Commit

Permalink
differentiate between linearIndex and index flattening operation
Browse files Browse the repository at this point in the history
  • Loading branch information
wrtobin committed Nov 10, 2023
1 parent 74c2f73 commit a24d221
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/common/IndexTypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ forRangeHelper( MultiIndexRange< BASE_INDEX_TYPE, RANGES... > const & start,
*/
template< typename BASE_INDEX_TYPE, BASE_INDEX_TYPE... RANGES >
SHIVA_CONSTEXPR_HOSTDEVICE_FORCEINLINE BASE_INDEX_TYPE
linearIndex( MultiIndexRange< BASE_INDEX_TYPE, RANGES... > const & index )
flattenIndex( MultiIndexRange< BASE_INDEX_TYPE, RANGES... > const & index )
{
using IndexType = MultiIndexRange< BASE_INDEX_TYPE, RANGES... >;
return detail::linearIndexHelper( index, std::make_integer_sequence< int, IndexType::NUM_INDICES >{} );
Expand All @@ -135,7 +135,7 @@ linearIndex( MultiIndexRange< BASE_INDEX_TYPE, RANGES... > const & index )
*/
template< typename INDEX_TYPE >
SHIVA_CONSTEXPR_HOSTDEVICE_FORCEINLINE INDEX_TYPE
linearIndex( LinearIndex< INDEX_TYPE > const & index )
flattenIndex( LinearIndex< INDEX_TYPE > const & index )
{
return index;
}
Expand Down
6 changes: 3 additions & 3 deletions src/common/unitTests/testIndexTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ void testLinearIndexTypeHelper()
LinearIndex< int > a = 0;
for ( a = 0, i = 0; a < 10; ++a, ++i )
{
kdata[i] = linearIndex( a );
kdata[i] = flattenIndex( a );
}
} );
for ( int i = 0; i < 10; ++i )
Expand Down Expand Up @@ -50,7 +50,7 @@ void testMultiIndexManualLoopHelper()
{
for ( c = 0; c < 2; ++c )
{
kdata[4 * a + 2 * b + c] = linearIndex( index );
kdata[4 * a + 2 * b + c] = flattenIndex( index );
}
}
}
Expand Down Expand Up @@ -83,7 +83,7 @@ void testMultiIndexForRangeHelper()

forRange( index, [&] ( auto const & i )
{
kdata[4 * i.data[0] + 2 * i.data[1] + i.data[2]] = linearIndex( i );
kdata[4 * i.data[0] + 2 * i.data[1] + i.data[2]] = flattenIndex( i );
} );
} );

Expand Down
6 changes: 3 additions & 3 deletions src/geometry/mapping/LinearTransform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,18 @@ jacobian( LinearTransform< REAL_TYPE, INTERPOLATED_SHAPE > const & cell,
using InterpolatedShape = typename Transform::InterpolatedShape;
constexpr int DIMS = Transform::numDims;

auto const & vertexCoords = cell.getData();
auto const & nodeCoords = cell.getData();
forNestedSequence( InterpolatedShape::basisSupportCounts,
[&] ( auto const ... icNa ) constexpr
{
IndexType index{ { decltype(icNa)::value... } };
CArray1d< REAL_TYPE, DIMS > const dNadXi = INTERPOLATED_SHAPE::template gradient< decltype(icNa)::value... >( pointCoordsParent );
auto const & vertexCoord = vertexCoords[ linearIndex( index ) ];
auto const & nodeCoord = nodeCoords[ flattenIndex( index ) ];
forNestedSequence< DIMS, DIMS >(
[&] ( auto const ... icijk ) constexpr
{
constexpr int ijk[DIMS] = { decltype(icijk)::value... };
J[ijk[1]][ijk[0]] = J[ijk[1]][ijk[0]] + dNadXi[ijk[0]] * vertexCoord[ijk[1]];
J[ijk[1]][ijk[0]] = J[ijk[1]][ijk[0]] + dNadXi[ijk[0]] * nodeCoord[ijk[1]];
} );
}
);
Expand Down
4 changes: 2 additions & 2 deletions src/geometry/mapping/unitTests/testLinearTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ SHIVA_HOST_DEVICE auto makeLinearTransform( REAL_TYPE const (&X)[8][3] )

for ( int j = 0; j < 3; ++j )
{
transformData[ linearIndex( i ) ][j] = X[ a + 2 * b + 4 * c ][j];
transformData[ flattenIndex( i ) ][j] = X[ a + 2 * b + 4 * c ][j];
}
} );

Expand All @@ -133,7 +133,7 @@ void testConstructionAndSettersHelper()

for ( int j = 0; j < 3; ++j )
{
kernelData[ 3 * ( a + 2 * b + 4 * c ) + j ] = transformData[linearIndex( i )][j];
kernelData[ 3 * ( a + 2 * b + 4 * c ) + j ] = transformData[flattenIndex( i )][j];
}
} );
} );
Expand Down
14 changes: 7 additions & 7 deletions src/geometry/shapes/InterpolatedShape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@ namespace geometry
* @brief Defines a class that provides static functions to calculate quantities
* required from the parent element in a finite element method.
* @tparam REAL_TYPE The floating point type to use
* @tparam BASE_SHAPE The cell type/geometry
* @tparam STANDARD_GEOMETRY The standard geometric form of the interpolated shape (domain)
* @tparam FUNCTIONAL_SPACE_TYPE The functional space type
* @tparam BASIS_TYPE Pack of basis types to apply to each direction of the
* parent element. There should be a basis defined for each direction.
*/
template< typename REAL_TYPE, typename BASE_SHAPE, typename ... BASIS_TYPE >
template< typename REAL_TYPE, typename STANDARD_GEOMETRY, typename ... BASIS_TYPE >
class InterpolatedShape
{

public:

/// The type used to represent the cell/geometry
using BaseShape = BASE_SHAPE;
using StandardGeom = STANDARD_GEOMETRY;
// using FunctionalSpaceType = FUNCTIONAL_SPACE_TYPE;
// using IndexType = typename BaseShape::IndexType;
// using IndexType = typename Geometry::IndexType;

/// Alias for the floating point type
using RealType = REAL_TYPE;

/// Alias for the type that represents a coordinate
using CoordType = typename BaseShape::CoordType;
using CoordType = typename StandardGeom::CoordType;

/// The type used to represent the product of basis functions
using BASIS_PRODUCT_TYPE = functions::BasisProduct< REAL_TYPE, BASIS_TYPE... >;
Expand All @@ -48,11 +48,11 @@ class InterpolatedShape
static inline constexpr int numDims = sizeof...(BASIS_TYPE);

/// The number of vertices on the InterpolatedShape
static inline constexpr int numVertices = BaseShape::numVertices();
static inline constexpr int numVertices = StandardGeom::numVertices();

static inline constexpr std::integer_sequence< int, BASIS_TYPE::numSupportPoints... > basisSupportCounts{};

static_assert( numDims == BaseShape::numDims(), "numDims mismatch between cell and number of basis specified" );
static_assert( numDims == StandardGeom::numDims(), "numDims mismatch between cell and number of basis specified" );


/**
Expand Down

0 comments on commit a24d221

Please sign in to comment.