Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(array): Add Presto function array_top_n #12105

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
#include "velox/expression/PrestoCastHooks.h"
#include "velox/functions/Udf.h"
#include "velox/functions/lib/CheckedArithmetic.h"
#include "velox/functions/lib/ComparatorUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"
#include "velox/type/FloatingPointUtil.h"

#include <queue>

namespace facebook::velox::functions {

template <typename TExecCtx, bool isMax>
Expand Down Expand Up @@ -729,6 +732,153 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
}
}

/// Implements the Presto array_top_n function.
///
/// SIGNATURE:
///   array_top_n(array(T), int) -> array(T)
///
/// Emits up to n elements of the input array, largest first. Non-null
/// elements are written in descending order. When fewer than n non-null
/// elements exist, nulls found in the input are appended after them until
/// either n entries have been written or the input's nulls are used up.
/// A negative n fails the VELOX_USER_CHECK_GE check; n == 0 or an empty input
/// leaves the result empty.
template <typename T>
struct ArrayTopNFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  /// Overload for primitive element types: element values themselves are
  /// copied into a min-heap that never keeps more than n entries.
  template <typename TReturn, typename TInput>
  FOLLY_ALWAYS_INLINE void
  call(TReturn& result, const TInput& array, int32_t n) {
    VELOX_USER_CHECK_GE(n, 0, "Parameter n: {} to ARRAY_TOP_N is negative", n);

    // Nothing to select from, or nothing requested.
    if (n == 0 || array.size() == 0) {
      return;
    }

    using Element = typename TInput::element_t;
    using facebook::velox::util::floating_point::NaNAwareGreaterThan;

    // Strict "greater than" over elements. float and double are routed
    // through NaNAwareGreaterThan; every other type uses std::greater.
    struct Greater {
      bool operator()(const Element& lhs, const Element& rhs) const {
        if constexpr (
            std::is_same_v<Element, float> || std::is_same_v<Element, double>) {
          return NaNAwareGreaterThan<Element>{}(lhs, rhs);
        } else {
          return std::greater<Element>{}(lhs, rhs);
        }
      }
    };
    const Greater isGreater{};

    // With a "greater" comparator std::priority_queue keeps the smallest
    // retained value on top, which is the one to evict when a larger value
    // shows up.
    std::priority_queue<Element, std::vector<Element>, Greater> smallestOnTop;

    int nullCount = 0;
    for (const auto& entry : array) {
      if (!entry.has_value()) {
        ++nullCount;
        continue;
      }
      if (smallestOnTop.size() < n) {
        smallestOnTop.push(entry.value());
        continue;
      }
      if (isGreater(entry.value(), smallestOnTop.top())) {
        // Insert first, then drop the current minimum, so the heap is back to
        // n entries.
        smallestOnTop.push(entry.value());
        smallestOnTop.pop();
      }
    }

    // Popping yields ascending order; collect it and then walk it backwards
    // to emit descending order.
    std::vector<Element> ascending;
    ascending.reserve(smallestOnTop.size());
    while (!smallestOnTop.empty()) {
      ascending.push_back(smallestOnTop.top());
      smallestOnTop.pop();
    }
    for (auto it = ascending.crbegin(); it != ascending.crend(); ++it) {
      result.push_back(*it);
    }

    // Pad with the input's nulls while there is still room below n.
    for (; result.size() < n && nullCount > 0; --nullCount) {
      result.add_null();
    }
  }

  /// Overload for generic orderable element types: the heap stores positions
  /// into the input array instead of element values, and elements are
  /// compared through their compare() method.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Array<Orderable<T1>>>& result,
      const arg_type<Array<Orderable<T1>>>& array,
      const int32_t n) {
    VELOX_USER_CHECK_GE(n, 0, "Parameter n: {} to ARRAY_TOP_N is negative", n);

    // Nothing to select from, or nothing requested.
    if (n == 0 || array.size() == 0) {
      return;
    }

    using InputArray = arg_type<Array<Orderable<T1>>>;

    // Orders two positions of the input array by the elements they refer to;
    // true when the element at `lhs` compares greater than the one at `rhs`.
    // Only positions of non-null elements are ever passed in.
    struct PositionGreater {
      const InputArray& elements;

      PositionGreater(const InputArray& input) : elements(input) {}

      bool operator()(const int32_t& lhs, const int32_t& rhs) const {
        // NOTE(review): with kNullAsIndeterminate the compare result is an
        // optional; .value() assumes it is always populated here - confirm
        // how nested nulls are reported by compare().
        static constexpr CompareFlags kFlags = {
            .nullHandlingMode =
                CompareFlags::NullHandlingMode::kNullAsIndeterminate};
        const auto ordering =
            elements[lhs].value().compare(elements[rhs].value(), kFlags);
        return ordering.value() > 0;
      }
    };
    const PositionGreater isGreater(array);

    // Min-heap of positions: the position of the smallest retained element is
    // on top.
    std::priority_queue<int32_t, std::vector<int32_t>, PositionGreater>
        smallestOnTop(isGreater);

    int nullCount = 0;
    for (int position = 0; position < array.size(); ++position) {
      if (!array[position].has_value()) {
        ++nullCount;
        continue;
      }
      if (smallestOnTop.size() < n) {
        smallestOnTop.push(position);
        continue;
      }
      if (isGreater(position, smallestOnTop.top())) {
        // Insert first, then drop the current minimum, so the heap is back to
        // n entries.
        smallestOnTop.push(position);
        smallestOnTop.pop();
      }
    }

    // Popping yields ascending order; collect it and then walk it backwards
    // to emit descending order.
    std::vector<int32_t> ascending;
    ascending.reserve(smallestOnTop.size());
    while (!smallestOnTop.empty()) {
      ascending.push_back(smallestOnTop.top());
      smallestOnTop.pop();
    }
    for (auto it = ascending.crbegin(); it != ascending.crend(); ++it) {
      result.push_back(array[*it].value());
    }

    // Pad with the input's nulls while there is still room below n.
    for (; result.size() < n && nullCount > 0; --nullCount) {
      result.add_null();
    }
  }
};

template <typename T>
struct ArrayTrimFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
{prefix + "trim_array"});
}

/// Registers ArrayTopNFunction for element type T under the name
/// `prefix + "array_top_n"`, with argument types (Array<T>, int32_t) and
/// return type Array<T>.
template <typename T>
inline void registerArrayTopNFunction(const std::string& prefix) {
  const std::string functionName = prefix + "array_top_n";
  registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int32_t>(
      {functionName});
}

template <typename T>
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
Expand Down Expand Up @@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
int64_t>({prefix + "trim_array"});

registerArrayTopNFunction<int8_t>(prefix);
registerArrayTopNFunction<int16_t>(prefix);
registerArrayTopNFunction<int32_t>(prefix);
registerArrayTopNFunction<int64_t>(prefix);
registerArrayTopNFunction<int128_t>(prefix);
registerArrayTopNFunction<float>(prefix);
registerArrayTopNFunction<double>(prefix);
registerArrayTopNFunction<Varchar>(prefix);
registerArrayTopNFunction<Timestamp>(prefix);
registerArrayTopNFunction<Date>(prefix);
registerArrayTopNFunction<Varbinary>(prefix);
registerArrayTopNFunction<Orderable<T1>>(prefix);

registerArrayRemoveNullFunctions<int8_t>(prefix);
registerArrayRemoveNullFunctions<int16_t>(prefix);
registerArrayRemoveNullFunctions<int32_t>(prefix);
Expand Down
Loading
Loading