Skip to content

Commit

Permalink
[oneDPL] Fix performance issue in __serial_merge (#2022)
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyKopienko authored Jan 27, 2025
1 parent 1746d46 commit 2a3a0b8
Showing 1 changed file with 33 additions and 5 deletions.
38 changes: 33 additions & 5 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
#include <limits> // std::numeric_limits
#include <cassert> // assert
#include <cstdint> // std::uint8_t, ...
#include <utility> // std::make_pair, std::forward
#include <utility> // std::make_pair, std::forward, std::declval
#include <algorithm> // std::min, std::lower_bound
#include <type_traits> // std::void_t, std::true_type, std::false_type

#include "sycl_defs.h"
#include "parallel_backend_sycl_utils.h"
Expand Down Expand Up @@ -130,6 +131,21 @@ __find_start_point(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_
return _split_point_t<_Index>{*__res, __index_sum - *__res + 1};
}

// Type of the expression `cond ? a : b` for operands of types _LhsType and _RhsType.
// The alias is ill-formed (a substitution failure) when the two operand types cannot be combined
// by the conditional operator.
template <typename _LhsType, typename _RhsType>
using __ternary_op_result_t = decltype(true ? std::declval<_LhsType>() : std::declval<_RhsType>());

// Compile-time check of whether the conditional (ternary) operator can be applied to operands of the
// two given types.
// Primary template: selected when the ternary expression is ill-formed, reports false.
template <typename _LhsType, typename _RhsType, typename = void>
struct __can_use_ternary_op : std::false_type
{
};

// Partial specialization: takes part in the selection only when the ternary expression is well-formed
// (otherwise it is discarded through std::void_t), reports true.
template <typename _LhsType, typename _RhsType>
struct __can_use_ternary_op<_LhsType, _RhsType, std::void_t<__ternary_op_result_t<_LhsType, _RhsType>>>
    : std::true_type
{
};

// Shorthand for __can_use_ternary_op<...>::value
template <typename _LhsType, typename _RhsType>
static constexpr bool __can_use_ternary_op_v = __can_use_ternary_op<_LhsType, _RhsType>::value;

// Serially merge the data from rng1 (starting from start1) and rng2 (starting from start2), writing the result
// to rng3 (starting from start3) in 'chunk' steps, without exceeding the total size of the sequences (n1 and n2)
template <typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
Expand All @@ -156,11 +172,23 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, const _I
// One of __rng1_idx_less_n1 and __rng2_idx_less_n2 should be true here
// because 1) we should fill output data with elements from one of the input ranges
// 2) we calculate __rng3_idx_end as std::min<_Index>(__rng1_size + __rng2_size, __chunk).
if (__rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]) ||
!__rng1_idx_less_n1)
__rng3[__rng3_idx] = __rng2[__rng2_idx++];
if constexpr (__can_use_ternary_op_v<decltype(__rng1[__rng1_idx]), decltype(__rng2[__rng2_idx])>)
{
// This implementation is required for performance optimization
__rng3[__rng3_idx] = (!__rng1_idx_less_n1 || __rng1_idx_less_n1 && __rng2_idx_less_n2 &&
__comp(__rng2[__rng2_idx], __rng1[__rng1_idx]))
? __rng2[__rng2_idx++]
: __rng1[__rng1_idx++];
}
else
__rng3[__rng3_idx] = __rng1[__rng1_idx++];
{
// TODO: investigate why the usual if-else is slower than the ternary operator
if (!__rng1_idx_less_n1 ||
__rng1_idx_less_n1 && __rng2_idx_less_n2 && __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]))
__rng3[__rng3_idx] = __rng2[__rng2_idx++];
else
__rng3[__rng3_idx] = __rng1[__rng1_idx++];
}
}
}

Expand Down

0 comments on commit 2a3a0b8

Please sign in to comment.