Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
boneanxs committed Jun 20, 2024
1 parent 6c91c8f commit 901aae4
Show file tree
Hide file tree
Showing 12 changed files with 230 additions and 388 deletions.
93 changes: 82 additions & 11 deletions velox/functions/lib/ArraySort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/ArraySort.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/type/FloatingPointUtil.h"
Expand Down Expand Up @@ -99,7 +99,13 @@ void applyComplexType(
bool throwOnNestedNull) {
auto inputElements = inputArray->elements();
auto indices = sortElements(
rows, *inputArray, *inputElements, ascending, nullsFirst, context, throwOnNestedNull);
rows,
*inputArray,
*inputElements,
ascending,
nullsFirst,
context,
throwOnNestedNull);
resultElements = BaseVector::transpose(indices, std::move(inputElements));
}

Expand Down Expand Up @@ -161,14 +167,15 @@ void applyScalarType(
}
}
} else {
// Move nulls to end of array.
for (vector_size_t i = size - 1; i >= 0; --i) {
if (flatResults->isNullAt(offset + i)) {
swapWithNull<T>(flatResults, offset + size - numNulls - 1, offset + i);
++numNulls;
// Move nulls to end of array.
for (vector_size_t i = size - 1; i >= 0; --i) {
if (flatResults->isNullAt(offset + i)) {
swapWithNull<T>(
flatResults, offset + size - numNulls - 1, offset + i);
++numNulls;
}
}
}
}
// Exclude null values while sorting.
const auto startRow = offset + (nullsFirst ? numNulls : 0);
const auto endRow = startRow + size - numNulls;
Expand Down Expand Up @@ -233,8 +240,13 @@ class ArraySortFunction : public exec::VectorFunction {
/// and 'offsets' vectors that control where output arrays start and end
/// remain the same in the output ArrayVector.

explicit ArraySortFunction(bool ascending, bool nullsFirst, bool throwOnNestedNull)
: ascending_{ascending}, nullsFirst_{nullsFirst}, throwOnNestedNull_(throwOnNestedNull) {}
explicit ArraySortFunction(
bool ascending,
bool nullsFirst,
bool throwOnNestedNull)
: ascending_{ascending},
nullsFirst_{nullsFirst},
throwOnNestedNull_(throwOnNestedNull) {}

// Execute function.
void apply(
Expand Down Expand Up @@ -480,6 +492,17 @@ std::shared_ptr<exec::VectorFunction> createAscNoThrowOnNestedNull(
return create(inputArgs, true, false, false);
}

core::CallTypedExprPtr asArraySortCall(
const std::string& prefix,
const core::TypedExprPtr& expr) {
if (auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(expr)) {
if (call->name() == prefix + "array_sort") {
return call;
}
}
return nullptr;
}

} // namespace

std::shared_ptr<exec::VectorFunction> makeArraySortLambdaFunction(
Expand All @@ -490,7 +513,7 @@ std::shared_ptr<exec::VectorFunction> makeArraySortLambdaFunction(
bool throwOnNestedNull) {
VELOX_CHECK_EQ(inputArgs.size(), 2);
return std::make_shared<ArraySortLambdaFunction>(
ascending, throwOnNestedNull);
ascending, throwOnNestedNull);
}

std::shared_ptr<exec::VectorFunction> makeArraySort(
Expand All @@ -508,6 +531,54 @@ std::vector<std::shared_ptr<exec::FunctionSignature>> arraySortSignatures(
return signatures(withComparator);
}

core::TypedExprPtr rewriteArraySortCall(
const std::string& prefix,
const core::TypedExprPtr& expr,
const std::shared_ptr<SimpleComparisonChecker> checker) {
auto call = asArraySortCall(prefix, expr);
if (call == nullptr || call->inputs().size() != 2) {
return nullptr;
}

auto lambda =
dynamic_cast<const core::LambdaTypedExpr*>(call->inputs()[1].get());
VELOX_CHECK_NOT_NULL(lambda);

// Extract 'transform' from the comparison lambda:
// (x, y) -> if(func(x) < func(y),...) ===> x -> func(x).
if (lambda->signature()->size() != 2) {
return nullptr;
}

static const std::string kNotSupported =
"array_sort with comparator lambda that cannot be rewritten "
"into a transform is not supported: {}";

if (auto comparison = checker->isSimpleComparison(prefix, *lambda)) {
std::string name = comparison->isLessThen ? prefix + "array_sort"
: prefix + "array_sort_desc";

if (!comparison->expr->type()->isOrderable()) {
VELOX_USER_FAIL(kNotSupported, lambda->toString())
}

auto rewritten = std::make_shared<core::CallTypedExpr>(
call->type(),
std::vector<core::TypedExprPtr>{
call->inputs()[0],
std::make_shared<core::LambdaTypedExpr>(
ROW({lambda->signature()->nameOf(0)},
{lambda->signature()->childAt(0)}),
comparison->expr),
},
name);

return rewritten;
}

VELOX_USER_FAIL(kNotSupported, lambda->toString())
}

// An internal function to canonicalize an array to allow for comparisons. Used
// in AggregationFuzzerTest. Details in
// https://github.com/facebookincubator/velox/issues/6999.
Expand Down
16 changes: 16 additions & 0 deletions velox/functions/lib/ArraySort.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#pragma once

#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/SimpleComparisonMatcher.h"

namespace facebook::velox::functions {

Expand Down Expand Up @@ -55,4 +56,19 @@ std::shared_ptr<exec::VectorFunction> makeArraySortLambdaFunction(
std::vector<std::shared_ptr<exec::FunctionSignature>> arraySortSignatures(
bool withComparator);

/// Analyzes array_sort(array, lambda) call to determine whether it can be
/// re-written into a simpler call that specifies sort-by expression.
///
/// For example, rewrites
/// array_sort(a, (x, y) -> if(length(x) < length(y), -1, if(length(x) >
/// length(y), 1, 0))
/// into
/// array_sort(a, x -> length(x))
///
/// Returns new expression or nullptr if rewrite is not possible.
core::TypedExprPtr rewriteArraySortCall(
const std::string& prefix,
const core::TypedExprPtr& expr,
const std::shared_ptr<SimpleComparisonChecker> checker);

} // namespace facebook::velox::functions
77 changes: 56 additions & 21 deletions velox/functions/lib/SimpleComparisonMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,15 @@ class ComparisonMatcher : public Matcher {
VELOX_CHECK_EQ(2, inputMatchers_.size());
}

virtual bool exprNameMatch(const std::string& name) {
return name == prefix_ + "eq" || name == prefix_ + "lt" ||
name == prefix_ + "gt";
}

bool match(const core::TypedExprPtr& expr) override {
if (auto call = dynamic_cast<const core::CallTypedExpr*>(expr.get())) {
const auto& name = call->name();
if (name == prefix_ + "eq" || name == prefix_ + "lt" ||
name == prefix_ + "gt") {
if (exprNameMatch(name)) {
if (allMatch(call->inputs(), inputMatchers_)) {
*op_ = name;
return true;
Expand All @@ -91,8 +95,10 @@ class ComparisonMatcher : public Matcher {
return false;
}

private:
protected:
const std::string prefix_;

private:
std::vector<MatcherPtr> inputMatchers_;
std::string* op_;
};
Expand Down Expand Up @@ -190,36 +196,65 @@ using ComparisonConstantMatcherPtr = std::shared_ptr<ComparisonConstantMatcher>;

class SimpleComparisonChecker {
protected:
virtual MatcherPtr ifelse(
MatcherPtr ifelse(
const MatcherPtr& condition,
const MatcherPtr& thenClause,
const MatcherPtr& elseClause) = 0;

virtual MatcherPtr comparison(
const std::string& prefix,
const MatcherPtr& left,
const MatcherPtr& right,
std::string* op) = 0;
const MatcherPtr& elseClause) {
return std::make_shared<IfMatcher>(
std::vector<MatcherPtr>{condition, thenClause, elseClause});
}

virtual MatcherPtr anySingleInput(
MatcherPtr anySingleInput(
core::TypedExprPtr* expr,
core::FieldAccessTypedExprPtr* input) = 0;
core::FieldAccessTypedExprPtr* input) {
return std::make_shared<AnySingleInputMatcher>(expr, input);
}

virtual MatcherPtr comparisonConstant(int64_t* value) = 0;
MatcherPtr comparisonConstant(int64_t* value) {
return std::make_shared<ComparisonConstantMatcher>(value);
}

virtual std::string invert(
const std::string& prefix,
const std::string& op) = 0;
std::string invert(const std::string& prefix, const std::string& op) {
return op == ltName(prefix) ? gtName(prefix) : ltName(prefix);
}

/// Returns true for a < b -> -1.
virtual bool isLessThen(
bool isLessThen(
const std::string& prefix,
const std::string& operation,
const core::FieldAccessTypedExprPtr& left,
int64_t result,
const std::string& inputLeft) = 0;

virtual std::string eqName(const std::string& prefix) = 0;
const std::string& inputLeft) {
std::string op =
(left->name() == inputLeft) ? operation : invert(prefix, operation);

if (op == ltName(prefix)) {
return result < 0;
}

return result > 0;
}

virtual MatcherPtr comparison(
const std::string& prefix,
const MatcherPtr& left,
const MatcherPtr& right,
std::string* op) {
return std::make_shared<ComparisonMatcher>(
prefix, std::vector<MatcherPtr>{left, right}, op);
}

virtual std::string eqName(const std::string& prefix) {
return prefix + "eq";
}

virtual std::string ltName(const std::string& prefix) {
return prefix + "lt";
}

virtual std::string gtName(const std::string& prefix) {
return prefix + "gt";
}

public:
virtual ~SimpleComparisonChecker() = default;
Expand Down
76 changes: 1 addition & 75 deletions velox/functions/prestosql/ArraySort.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,10 @@

#include "velox/core/Expressions.h"
#include "velox/functions/lib/ArraySort.h"
#include "velox/functions/prestosql/SimpleComparisonChecker.h"
#include "velox/functions/lib/SimpleComparisonMatcher.h"

namespace facebook::velox::functions {

namespace {
core::CallTypedExprPtr asArraySortCall(
const std::string& prefix,
const core::TypedExprPtr& expr) {
if (auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(expr)) {
if (call->name() == prefix + "array_sort") {
return call;
}
}
return nullptr;
}

} // namespace

std::shared_ptr<exec::VectorFunction> makeArraySortAsc(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
Expand All @@ -59,64 +45,4 @@ std::shared_ptr<exec::VectorFunction> makeArraySortDesc(
return makeArraySort(name, inputArgs, config, false, false, true);
}

/// Analyzes array_sort(array, lambda) call to determine whether it can be
/// re-written into a simpler call that specifies sort-by expression.
///
/// For example, rewrites
/// array_sort(a, (x, y) -> if(length(x) < length(y), -1, if(length(x) >
/// length(y), 1, 0))
/// into
/// array_sort(a, x -> length(x))
///
/// Returns new expression or nullptr if rewrite is not possible.
core::TypedExprPtr rewriteArraySortCall(
const std::string& prefix,
const core::TypedExprPtr& expr) {
auto call = asArraySortCall(prefix, expr);
if (call == nullptr || call->inputs().size() != 2) {
return nullptr;
}

auto lambda =
dynamic_cast<const core::LambdaTypedExpr*>(call->inputs()[1].get());
VELOX_CHECK_NOT_NULL(lambda);

// Extract 'transform' from the comparison lambda:
// (x, y) -> if(func(x) < func(y),...) ===> x -> func(x).
if (lambda->signature()->size() != 2) {
return nullptr;
}

static const std::string kNotSupported =
"array_sort with comparator lambda that cannot be rewritten "
"into a transform is not supported: {}";

auto checker =
std::make_unique<functions::prestosql::PrestoSimpleComparisonChecker>();

if (auto comparison = checker->isSimpleComparison(prefix, *lambda)) {
std::string name = comparison->isLessThen ? prefix + "array_sort"
: prefix + "array_sort_desc";

if (!comparison->expr->type()->isOrderable()) {
VELOX_USER_FAIL(kNotSupported, lambda->toString())
}

auto rewritten = std::make_shared<core::CallTypedExpr>(
call->type(),
std::vector<core::TypedExprPtr>{
call->inputs()[0],
std::make_shared<core::LambdaTypedExpr>(
ROW({lambda->signature()->nameOf(0)},
{lambda->signature()->childAt(0)}),
comparison->expr),
},
name);

return rewritten;
}

VELOX_USER_FAIL(kNotSupported, lambda->toString())
}

} // namespace facebook::velox::functions
Loading

0 comments on commit 901aae4

Please sign in to comment.