Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Spark ArraySort with lambda function #10138

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions velox/docs/functions/spark/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,36 @@ Array Functions
.. spark:function:: array_sort(array(E)) -> array(E)

Returns an array which has the sorted order of the input array(E). The elements of array(E) must
be orderable. Null elements will be placed at the end of the returned array. ::
be orderable. NULL and NaN elements will be placed at the end of the returned array, with NaN elements appearing before NULL elements for floating-point types. ::

SELECT array_sort(array(1, 2, 3)); -- [1, 2, 3]
SELECT array_sort(array(3, 2, 1)); -- [1, 2, 3]
SELECT array_sort(array(2, 1, NULL); -- [1, 2, NULL]
SELECT array_sort(array(2, 1, NULL)); -- [1, 2, NULL]
SELECT array_sort(array(NULL, 1, NULL)); -- [1, NULL, NULL]
SELECT array_sort(array(NULL, 2, 1)); -- [1, 2, NULL]
SELECT array_sort(array(4.0, NULL, float('nan'), 3.0)); -- [3.0, 4.0, NaN, NULL]
SELECT array_sort(array(array(), array(1, 3, NULL), array(NULL, 6), NULL, array(2, 1))); -- [[], [NULL, 6], [1, 3, NULL], [2, 1], NULL]

.. spark:function:: array_sort(array(T), function(T,U)) -> array(T)
rui-mo marked this conversation as resolved.
Show resolved Hide resolved
:noindex:

Returns the array sorted by values computed using specified lambda in ascending order. ``U`` must be an orderable type.
NULL and NaN elements returned by the lambda function will be placed at the end of the returned array, with NaN elements appearing before NULL elements.
This function is not supported in Spark and is only used inside Velox for rewriting :spark:func:`array_sort(array(T), function(T,T,U)) -> array(T)` as :spark:func:`array_sort(array(T), function(T,U)) -> array(T)`. ::

.. spark:function:: array_sort(array(T), function(T,T,U)) -> array(T)
rui-mo marked this conversation as resolved.
Show resolved Hide resolved
:noindex:

Returns the array sorted by values computed using specified lambda in ascending
order. ``U`` must be an orderable type.
The function attempts to analyze the lambda function and rewrite it into a simpler call that
rui-mo marked this conversation as resolved.
Show resolved Hide resolved
specifies the sort-by expression (like :spark:func:`array_sort(array(T), function(T,U)) -> array(T)`). For example, ``(left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))`` will be rewritten to ``x -> length(x)``. If rewrite is not possible, a user error will be thrown.
If the rewritten function returns NULL, the corresponding element will be placed at the end.
Please note that due to this rewrite optimization, there is a difference in NULL handling logic between Spark and Velox. In Velox, NULL elements are always placed at the end of the returned array, whereas in Spark, it depends on the comparison logic to compare NULL with other elements. ::

SELECT array_sort(array('cat', 'leopard', 'mouse'), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ['cat', 'mouse', 'leopard']
select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) < length(right), -1, 0))); -- ["abc", "abcd", "abcd123", NULL]
select array_sort(array("abcd123", "abcd", NULL, "abc"), (left, right) -> if(length(left) > length(right), 1, if(length(left) = length(right), 0, -1))); -- ["abc", "abcd", "abcd123", NULL] different with Spark: ["abc", NULL, "abcd", "abcd123"]

.. spark::function:: arrays_zip(array(T), array(U),..) -> array(row(T,U, ...))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/ArraySort.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/functions/prestosql/SimpleComparisonMatcher.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {
Expand All @@ -32,6 +31,7 @@ BufferPtr sortElements(
const ArrayVector& inputArray,
const BaseVector& inputElements,
bool ascending,
bool nullsFirst,
exec::EvalCtx& context,
bool throwOnNestedNull) {
const SelectivityVector inputElementRows =
Expand All @@ -44,7 +44,7 @@ BufferPtr sortElements(
BufferPtr indices = allocateIndices(inputElements.size(), context.pool());
vector_size_t* rawIndices = indices->asMutable<vector_size_t>();

CompareFlags flags{.nullsFirst = false, .ascending = ascending};
CompareFlags flags{.nullsFirst = nullsFirst, .ascending = ascending};
if (throwOnNestedNull) {
flags.nullHandlingMode =
CompareFlags::NullHandlingMode::kNullAsIndeterminate;
Expand All @@ -69,10 +69,10 @@ BufferPtr sortElements(
bool aNull = decodedElements->isNullAt(a);
bool bNull = decodedElements->isNullAt(b);
if (aNull) {
return false;
return nullsFirst;
}
if (bNull) {
return true;
return !nullsFirst;
}

std::optional<int32_t> result = baseElementsVector->compare(
Expand All @@ -93,12 +93,19 @@ void applyComplexType(
const SelectivityVector& rows,
ArrayVector* inputArray,
bool ascending,
bool nullsFirst,
exec::EvalCtx& context,
VectorPtr& resultElements,
bool throwOnNestedNull) {
auto inputElements = inputArray->elements();
auto indices = sortElements(
rows, *inputArray, *inputElements, ascending, context, throwOnNestedNull);
rows,
*inputArray,
*inputElements,
ascending,
nullsFirst,
context,
throwOnNestedNull);
resultElements = BaseVector::transpose(indices, std::move(inputElements));
}

Expand All @@ -122,6 +129,7 @@ void applyScalarType(
const SelectivityVector& rows,
const ArrayVector* inputArray,
bool ascending,
bool nullsFirst,
exec::EvalCtx& context,
VectorPtr& resultElements) {
using T = typename TypeTraits<kind>::NativeType;
Expand Down Expand Up @@ -150,28 +158,39 @@ void applyScalarType(
return;
}
vector_size_t numNulls = 0;
// Move nulls to end of array.
for (vector_size_t i = size - 1; i >= 0; --i) {
if (flatResults->isNullAt(offset + i)) {
swapWithNull<T>(flatResults, offset + size - numNulls - 1, offset + i);
++numNulls;
if (nullsFirst) {
// Move nulls to beginning of array.
for (vector_size_t i = 0; i < size; ++i) {
if (flatResults->isNullAt(offset + i)) {
swapWithNull<T>(flatResults, offset + numNulls, offset + i);
++numNulls;
}
}
} else {
// Move nulls to end of array.
for (vector_size_t i = size - 1; i >= 0; --i) {
if (flatResults->isNullAt(offset + i)) {
swapWithNull<T>(
flatResults, offset + size - numNulls - 1, offset + i);
++numNulls;
}
}
}
// Exclude null values while sorting.
const auto startRow = offset;
const auto startRow = offset + (nullsFirst ? numNulls : 0);
const auto endRow = startRow + size - numNulls;

if constexpr (kind == TypeKind::BOOLEAN) {
uint64_t* rawBits = flatResults->template mutableRawValues<uint64_t>();
const auto numOneBits = bits::countBits(rawBits, startRow, endRow);
const auto endZeroRow = endRow - numOneBits;

if (ascending) {
const auto endZeroRow = endRow - numOneBits;
bits::fillBits(rawBits, startRow, endZeroRow, false);
bits::fillBits(rawBits, endZeroRow, endRow, true);
boneanxs marked this conversation as resolved.
Show resolved Hide resolved
} else {
bits::fillBits(rawBits, startRow, startRow + numOneBits, true);
bits::fillBits(rawBits, endZeroRow, endRow, false);
bits::fillBits(rawBits, startRow + numOneBits, endRow, false);
boneanxs marked this conversation as resolved.
Show resolved Hide resolved
}
} else if constexpr (kind == TypeKind::REAL || kind == TypeKind::DOUBLE) {
T* resultRawValues = flatResults->mutableRawValues();
Expand Down Expand Up @@ -221,8 +240,13 @@ class ArraySortFunction : public exec::VectorFunction {
/// and 'offsets' vectors that control where output arrays start and end
/// remain the same in the output ArrayVector.

explicit ArraySortFunction(bool ascending, bool throwOnNestedNull)
: ascending_{ascending}, throwOnNestedNull_(throwOnNestedNull) {}
explicit ArraySortFunction(
bool ascending,
bool nullsFirst,
bool throwOnNestedNull)
: ascending_{ascending},
nullsFirst_{nullsFirst},
throwOnNestedNull_{throwOnNestedNull} {}

// Execute function.
void apply(
Expand All @@ -231,7 +255,6 @@ class ArraySortFunction : public exec::VectorFunction {
const TypePtr& /*outputType*/,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_EQ(args.size(), 1);
auto& arg = args[0];

VectorPtr localResult;
Expand Down Expand Up @@ -272,6 +295,7 @@ class ArraySortFunction : public exec::VectorFunction {
rows,
inputArray,
ascending_,
nullsFirst_,
context,
resultElements);

Expand All @@ -280,6 +304,7 @@ class ArraySortFunction : public exec::VectorFunction {
rows,
inputArray,
ascending_,
nullsFirst_,
context,
resultElements,
throwOnNestedNull_);
Expand All @@ -297,6 +322,7 @@ class ArraySortFunction : public exec::VectorFunction {
}

const bool ascending_;
const bool nullsFirst_;
const bool throwOnNestedNull_;
};

Expand Down Expand Up @@ -354,6 +380,7 @@ class ArraySortLambdaFunction : public exec::VectorFunction {
*flatArray,
*newElements,
ascending_,
false /*nullsFirst*/,
context,
throwOnNestedNull_);
auto sortedElements = BaseVector::wrapInDictionary(
Expand Down Expand Up @@ -386,48 +413,10 @@ template <TypeKind kind>
std::shared_ptr<exec::VectorFunction> createTyped(
const std::vector<exec::VectorFunctionArg>& inputArgs,
bool ascending,
bool nullsFirst,
bool throwOnNestedNull = true) {
VELOX_CHECK_EQ(inputArgs.size(), 1);
rui-mo marked this conversation as resolved.
Show resolved Hide resolved
return std::make_shared<ArraySortFunction<kind>>(
ascending, throwOnNestedNull);
}

// Create function.
std::shared_ptr<exec::VectorFunction> create(
const std::vector<exec::VectorFunctionArg>& inputArgs,
bool ascending,
bool throwOnNestedNull = true) {
if (inputArgs.size() == 2) {
return std::make_shared<ArraySortLambdaFunction>(
ascending, throwOnNestedNull);
}

const auto elementType = inputArgs.front().type->childAt(0);
if (elementType->isUnKnown()) {
return createTyped<TypeKind::UNKNOWN>(
inputArgs, ascending, throwOnNestedNull);
}

return VELOX_DYNAMIC_TYPE_DISPATCH(
createTyped,
elementType->kind(),
inputArgs,
ascending,
throwOnNestedNull);
}

std::shared_ptr<exec::VectorFunction> createAsc(
const std::string& /* name */,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
return create(inputArgs, true, true);
}

std::shared_ptr<exec::VectorFunction> createDesc(
const std::string& /* name */,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
return create(inputArgs, false, true);
ascending, nullsFirst, throwOnNestedNull);
}

// Define function signature.
Expand Down Expand Up @@ -475,11 +464,17 @@ internalCanonicalizeSignatures() {
return signatures;
}

std::shared_ptr<exec::VectorFunction> createAscNoThrowOnNestedNull(
const std::string& /* name */,
std::shared_ptr<exec::VectorFunction> makeArraySortAscNoThrowOnNestedNull(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
return create(inputArgs, true, false);
const core::QueryConfig& config) {
return makeArraySort(
name,
inputArgs,
config,
true,
false /*nullsFirst=*/,
false /*throwOnNestedNull=*/);
}

core::CallTypedExprPtr asArraySortCall(
Expand All @@ -495,30 +490,65 @@ core::CallTypedExprPtr asArraySortCall(

} // namespace

std::shared_ptr<exec::VectorFunction> makeArraySortLambdaFunction(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/,
bool ascending,
bool throwOnNestedNull) {
VELOX_CHECK_EQ(inputArgs.size(), 2);
rui-mo marked this conversation as resolved.
Show resolved Hide resolved
return std::make_shared<ArraySortLambdaFunction>(
ascending, throwOnNestedNull);
}

std::shared_ptr<exec::VectorFunction> makeArraySort(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/,
bool ascending,
bool nullsFirst,
bool throwOnNestedNull) {
const auto elementType = inputArgs.front().type->childAt(0);
if (elementType->isUnKnown()) {
return createTyped<TypeKind::UNKNOWN>(
inputArgs, ascending, nullsFirst, throwOnNestedNull);
}

return VELOX_DYNAMIC_TYPE_DISPATCH(
createTyped,
elementType->kind(),
inputArgs,
ascending,
nullsFirst,
throwOnNestedNull);
}

std::vector<std::shared_ptr<exec::FunctionSignature>> arraySortSignatures(
bool withComparator) {
return signatures(withComparator);
}

core::TypedExprPtr rewriteArraySortCall(
const std::string& prefix,
const core::TypedExprPtr& expr) {
const core::TypedExprPtr& expr,
const std::shared_ptr<SimpleComparisonChecker> checker) {
auto call = asArraySortCall(prefix, expr);
if (call == nullptr || call->inputs().size() != 2) {
return nullptr;
}

auto lambda =
dynamic_cast<const core::LambdaTypedExpr*>(call->inputs()[1].get());
VELOX_CHECK_NOT_NULL(lambda);

// Extract 'transform' from the comparison lambda:
// (x, y) -> if(func(x) < func(y),...) ===> x -> func(x).
if (lambda->signature()->size() != 2) {
return nullptr;
}

static const std::string kNotSupported =
"array_sort with comparator lambda that cannot be rewritten "
"into a transform is not supported: {}";

if (auto comparison =
functions::prestosql::isSimpleComparison(prefix, *lambda)) {
if (auto comparison = checker->isSimpleComparison(prefix, *lambda)) {
std::string name = comparison->isLessThen ? prefix + "array_sort"
: prefix + "array_sort_desc";

Expand All @@ -543,23 +573,12 @@ core::TypedExprPtr rewriteArraySortCall(
VELOX_USER_FAIL(kNotSupported, lambda->toString());
}

// Register function.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_sort,
signatures(true),
createAsc);

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_array_sort_desc,
signatures(false),
createDesc);

// An internal function to canonicalize an array to allow for comparisons. Used
// in AggregationFuzzerTest. Details in
// https://github.com/facebookincubator/velox/issues/6999.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION(
udf_$internal$canonicalize,
internalCanonicalizeSignatures(),
createAscNoThrowOnNestedNull);
makeArraySortAscNoThrowOnNestedNull);

} // namespace facebook::velox::functions
Loading
Loading