Skip to content

Commit

Permalink
Merge pull request #786 from kordejong/gh785
Browse files Browse the repository at this point in the history
  • Loading branch information
kordejong authored Jan 21, 2025
2 parents f888adb + a5e4322 commit a62f4cb
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 70 deletions.
2 changes: 1 addition & 1 deletion source/framework/algorithm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ block()

foreach(Policies IN LISTS LUE_FRAMEWORK_ALGORITHM_POLICIES)
foreach(InputElement IN LISTS LUE_FRAMEWORK_INTEGRAL_ELEMENTS)
foreach(OutputElement IN LISTS LUE_FRAMEWORK_FLOATING_POINT_ELEMENTS)
foreach(OutputElement IN LISTS LUE_FRAMEWORK_ELEMENTS)
foreach(rank IN LISTS LUE_FRAMEWORK_RANKS)
math(EXPR count "${count} + 1")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ namespace lue {
namespace default_policies {

template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultPolicies<FromElement, ToElement>;

Expand All @@ -28,9 +29,9 @@ namespace lue {


template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table)
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultPolicies<FromElement, ToElement>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ namespace lue {
namespace detail {

template<typename OutputPartition, typename Policies, typename InputPartition>
OutputPartition reclassify_partition_ready(
auto reclassify_partition_ready(
Policies const& policies,
InputPartition const& input_partition,
LookupTable<ElementT<InputPartition>, ElementT<OutputPartition>> const& lookup_table)
-> OutputPartition
{
using Offset = OffsetT<InputPartition>;
using InputData = DataT<InputPartition>;
Expand Down Expand Up @@ -68,11 +69,11 @@ namespace lue {


template<typename Policies, typename InputPartition, typename OutputPartition>
OutputPartition reclassify_partition(
auto reclassify_partition(
Policies const& policies,
InputPartition const& input_partition,
hpx::shared_future<LookupTable<ElementT<InputPartition>, ElementT<OutputPartition>>> const&
lookup_table)
lookup_table) -> OutputPartition
{
using FromElement = ElementT<InputPartition>;
using ToElement = ElementT<OutputPartition>;
Expand Down Expand Up @@ -105,10 +106,11 @@ namespace lue {


template<typename Policies, typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
Policies const& policies,
PartitionedArray<FromElement, rank> const& input_array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>
{
// Spawn a task for each partition that will reclassify it

Expand All @@ -128,10 +130,10 @@ namespace lue {
InputPartitions const& input_partitions{input_array.partitions()};
OutputPartitions output_partitions{shape_in_partitions(input_array)};

for (Index p = 0; p < nr_partitions(input_array); ++p)
for (Index partition_idx = 0; partition_idx < nr_partitions(input_array); ++partition_idx)
{
output_partitions[p] =
hpx::async(action, localities[p], policies, input_partitions[p], lookup_table);
output_partitions[partition_idx] = hpx::async(
action, localities[partition_idx], policies, input_partitions[partition_idx], lookup_table);
}

return OutputArray{shape(input_array), localities, std::move(output_partitions)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ namespace lue {


template<typename Policies, typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
Policies const& policies,
PartitionedArray<FromElement, rank> const& array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table);
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>;


template<typename Policies, typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
Policies const& policies,
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table)
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
{
return reclassify(policies, array, hpx::make_ready_future(lookup_table).share());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ namespace lue {
namespace value_policies {

template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
hpx::shared_future<LookupTable<FromElement, ToElement>> const& lookup_table)
-> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultValuePolicies<FromElement, ToElement>;

Expand All @@ -28,9 +29,9 @@ namespace lue {


template<typename FromElement, typename ToElement, Rank rank>
PartitionedArray<ToElement, rank> reclassify(
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table)
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
{
using Policies = policy::reclassify::DefaultValuePolicies<FromElement, ToElement>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,48 @@ using namespace pybind11::literals;
namespace lue::framework {
namespace {

template<typename FromElement, typename ToElement, Rank rank>
auto reclassify2(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table) -> PartitionedArray<ToElement, rank>
template<typename ToElement, std::integral FromElement, std::floating_point LUTElement>
auto cast_lut(LookupTable<FromElement, LUTElement> const& lookup_table)
-> LookupTable<FromElement, ToElement>
{
return value_policies::reclassify(array, lookup_table);
}


template<typename ToElement2, typename FromElement, typename ToElement1>
auto cast_lut(LookupTable<FromElement, ToElement1> const& lookup_table)
-> LookupTable<FromElement, ToElement2>
{
static_assert(std::is_integral_v<FromElement>);
static_assert(std::is_floating_point_v<ToElement1>);
static_assert(std::is_floating_point_v<ToElement2>);

if constexpr (std::is_same_v<ToElement1, ToElement2>)
if constexpr (std::is_same_v<LUTElement, ToElement>)
{
return lookup_table;
}
else
{
LookupTable<FromElement, ToElement2> result;
LookupTable<FromElement, ToElement> result;

for (auto const& [key, value] : lookup_table)
{
result[key] = static_cast<ToElement2>(value);
result[key] = static_cast<ToElement>(value);
}

return result;
}
}


template<typename FromElement, typename ToElement, Rank rank>
template<Arithmetic ToElement, std::integral FromElement, std::floating_point LUTElement, Rank rank>
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table,
LookupTable<FromElement, LUTElement> const& lookup_table) -> pybind11::object
{
pybind11::object result{};

if constexpr (arithmetic_element_supported<ToElement>)
{
result = pybind11::cast(value_policies::reclassify(array, cast_lut<ToElement>(lookup_table)));
}

return result;
}


template<std::integral FromElement, std::floating_point LUTElement, Rank rank>
auto reclassify(
PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, LUTElement> const& lookup_table,
pybind11::dtype const& dtype) -> pybind11::object
{
// Switch on dtype and call a function that returns an array of the
Expand All @@ -65,12 +68,19 @@ namespace lue::framework {
// Signed integer
switch (size)
{
case 1:
{
result = reclassify<std::int8_t>(array, lookup_table);
break;
}
case 4:
{
result = reclassify<std::int32_t>(array, lookup_table);
break;
}
case 8:
{
result = reclassify<std::int64_t>(array, lookup_table);
break;
}
}
Expand All @@ -84,14 +94,17 @@ namespace lue::framework {
{
case 1:
{
result = reclassify<std::uint8_t>(array, lookup_table);
break;
}
case 4:
{
result = reclassify<std::uint32_t>(array, lookup_table);
break;
}
case 8:
{
result = reclassify<std::uint64_t>(array, lookup_table);
break;
}
}
Expand All @@ -105,24 +118,12 @@ namespace lue::framework {
{
case 4:
{
using Element = float;

if constexpr (arithmetic_element_supported<Element>)
{
result = pybind11::cast(reclassify2(array, cast_lut<Element>(lookup_table)));
}

result = reclassify<float>(array, lookup_table);
break;
}
case 8:
{
using Element = double;

if constexpr (arithmetic_element_supported<Element>)
{
result = pybind11::cast(reclassify2(array, cast_lut<Element>(lookup_table)));
}

result = reclassify<double>(array, lookup_table);
break;
}
}
Expand All @@ -140,16 +141,6 @@ namespace lue::framework {
}


// template<typename FromElement, typename ToElement, Rank rank>
// auto reclassify(
// PartitionedArray<FromElement, rank> const& array,
// LookupTable<FromElement, ToElement> const& lookup_table,
// pybind11::object const& dtype_args) -> pybind11::object
// {
// return reclassify1(array, lookup_table, pybind11::dtype::from_args(dtype_args));
// }


class Binder
{

Expand All @@ -160,12 +151,12 @@ namespace lue::framework {
{
Rank const rank{2};
using FromElement = Element;
using ToElement = LargestFloatingPointElement;
using LUTElement = LargestFloatingPointElement;

module.def(
"reclassify",
[](PartitionedArray<FromElement, rank> const& array,
LookupTable<FromElement, ToElement> const& lookup_table,
LookupTable<FromElement, LUTElement> const& lookup_table,
pybind11::object const& dtype_args) // -> pybind11::object
{ return reclassify(array, lookup_table, pybind11::dtype::from_args(dtype_args)); },
"array"_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def test_overloads(self):
for from_element_type in lfr.integral_element_types:
ids = lfr.create_array(array_shape, from_element_type, id_)

for to_element_type in lfr.floating_point_element_types:
for to_element_type in lfr.arithmetic_element_types:
lookup_table = {
1: 1.1,
2: 2.2,
3: 3.3,
4: 4.4,
1: 4,
2: 3,
3: 2,
4: 1,
}
array = lfr.reclassify(ids, lookup_table, dtype=to_element_type)

Expand Down

0 comments on commit a62f4cb

Please sign in to comment.