diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 61335de6ac09a..53ceed1b0893e 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -270,6 +270,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kRankQuantileOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType)); diff --git a/python/pyarrow/_acero.pyx b/python/pyarrow/_acero.pyx index 9e8cbd65be224..d49945ed70009 100644 --- a/python/pyarrow/_acero.pyx +++ b/python/pyarrow/_acero.pyx @@ -30,7 +30,7 @@ from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table, from pyarrow.lib import frombytes, tobytes from pyarrow._compute cimport ( Expression, FunctionOptions, _ensure_field_ref, _true, - unwrap_null_placement, unwrap_sort_order + unwrap_null_placement, unwrap_sort_keys ) @@ -234,17 +234,10 @@ class AggregateNodeOptions(_AggregateNodeOptions): cdef class _OrderByNodeOptions(ExecNodeOptions): def _set_options(self, sort_keys, null_placement): - cdef: - vector[CSortKey] c_sort_keys - - for name, order in sort_keys: - c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) - ) - self.wrapped.reset( new COrderByNodeOptions( - COrdering(c_sort_keys, unwrap_null_placement(null_placement)) + COrdering(unwrap_sort_keys(sort_keys, allow_str=False), + unwrap_null_placement(null_placement)) ) ) diff --git a/python/pyarrow/_compute.pxd b/python/pyarrow/_compute.pxd index 29b37da3ac4ef..648c1e0e2e5b3 100644 --- a/python/pyarrow/_compute.pxd +++ b/python/pyarrow/_compute.pxd @@ -65,6 +65,8 @@ cdef CExpression _true cdef CFieldRef _ensure_field_ref(value) except * +cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=*) except * + cdef CSortOrder unwrap_sort_order(order) except * cdef CNullPlacement unwrap_null_placement(null_placement) except * diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 658f6b6cac4b5..d23286dcdd02e 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -74,6 +74,20 @@ def _forbid_instantiation(klass, subclasses_instead=True): raise TypeError(msg) +cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True): + cdef vector[CSortKey] c_sort_keys + if allow_str and isinstance(sort_keys, str): + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) + ) + else: + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + ) + return c_sort_keys + + cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): """ Wrap a C++ scalar Function in a ScalarFunction object. @@ -2093,13 +2107,9 @@ class ArraySortOptions(_ArraySortOptions): cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): - cdef vector[CSortKey] c_sort_keys - for name, order in sort_keys: - c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) - ) self.wrapped.reset(new CSortOptions( - c_sort_keys, unwrap_null_placement(null_placement))) + unwrap_sort_keys(sort_keys, allow_str=False), + unwrap_null_placement(null_placement))) class SortOptions(_SortOptions): @@ -2125,12 +2135,7 @@ class SortOptions(_SortOptions): cdef class _SelectKOptions(FunctionOptions): def _set_options(self, k, sort_keys): - cdef vector[CSortKey] c_sort_keys - for name, order in sort_keys: - c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) - ) - self.wrapped.reset(new CSelectKOptions(k, c_sort_keys)) + self.wrapped.reset(new CSelectKOptions(k, unwrap_sort_keys(sort_keys, allow_str=False))) class SelectKOptions(_SelectKOptions): @@ -2317,19 +2322,9 @@ cdef class _RankOptions(FunctionOptions): } def _set_options(self, sort_keys, null_placement, tiebreaker): - cdef vector[CSortKey] c_sort_keys - if isinstance(sort_keys, str): - c_sort_keys.push_back( - CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) - ) - else: - for name, order in sort_keys: - c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) - ) try: self.wrapped.reset( - new CRankOptions(c_sort_keys, + new CRankOptions(unwrap_sort_keys(sort_keys), unwrap_null_placement(null_placement), self._tiebreaker_map[tiebreaker]) ) @@ -2370,6 +2365,37 @@ class RankOptions(_RankOptions): self._set_options(sort_keys, null_placement, tiebreaker) +cdef class _RankQuantileOptions(FunctionOptions): + + def _set_options(self, sort_keys, null_placement): + self.wrapped.reset( + new CRankQuantileOptions(unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement)) + ) + + +class RankQuantileOptions(_RankQuantileOptions): + """ + Options for the `rank_quantile` function. + + Parameters + ---------- + sort_keys : sequence of (name, order) tuples or str, default "ascending" + Names of field/column keys to sort the input on, + along with the order each field/column is sorted in. + Accepted values for `order` are "ascending", "descending". + The field name can be a string column name or expression. + Alternatively, one can simply pass "ascending" or "descending" as a string + if the input is array-like. + null_placement : str, default "at_end" + Where nulls in input should be sorted. + Accepted values are "at_start", "at_end". + """ + + def __init__(self, sort_keys="ascending", *, null_placement="at_end"): + self._set_options(sort_keys, null_placement) + + cdef class Expression(_Weakrefable): """ A logical expression to be evaluated against some input. diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 426ecae31c039..5348336235118 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -56,6 +56,7 @@ QuantileOptions, RandomOptions, RankOptions, + RankQuantileOptions, ReplaceSliceOptions, ReplaceSubstringOptions, RoundBinaryOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 021c1c782c6e5..88ad77e56f8fc 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2788,6 +2788,12 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CNullPlacement null_placement CRankOptionsTiebreaker tiebreaker + cdef cppclass CRankQuantileOptions \ + "arrow::compute::RankQuantileOptions"(CFunctionOptions): + CRankQuantileOptions(vector[CSortKey] sort_keys, CNullPlacement) + vector[CSortKey] sort_keys + CNullPlacement null_placement + cdef enum DatumType" arrow::Datum::type": DatumType_NONE" arrow::Datum::NONE" DatumType_SCALAR" arrow::Datum::SCALAR" diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 6f28205a18e13..e6d3b80bda953 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -172,6 +172,8 @@ def test_option_class_equality(request): pc.RandomOptions(), pc.RankOptions(sort_keys="ascending", null_placement="at_start", tiebreaker="max"), + pc.RankQuantileOptions(sort_keys="ascending", + null_placement="at_start"), pc.ReplaceSliceOptions(0, 1, "a"), pc.ReplaceSubstringOptions("a", "b"), pc.RoundOptions(2, "towards_infinity"), @@ -3360,6 +3362,36 @@ def test_rank_options(): tiebreaker="NonExisting") +def test_rank_quantile_options(): + arr = pa.array([None, 1, None, 2, None]) + expected = pa.array([0.7, 0.1, 0.7, 0.3, 0.7], type=pa.float64()) + + # Ensure rank_quantile can be called without specifying options + result = pc.rank_quantile(arr) + assert result.equals(expected) + + # Ensure default RankOptions + result = pc.rank_quantile(arr, options=pc.RankQuantileOptions()) + assert result.equals(expected) + + # Ensure sort_keys tuple usage + result = pc.rank_quantile(arr, options=pc.RankQuantileOptions( + sort_keys=[("b", "ascending")]) + ) + assert result.equals(expected) + + result = pc.rank_quantile(arr, null_placement="at_start") + expected_at_start = pa.array([0.3, 0.7, 0.3, 0.9, 0.3], type=pa.float64()) + assert result.equals(expected_at_start) + + result = pc.rank_quantile(arr, sort_keys="descending") + expected_descending = pa.array([0.7, 0.3, 0.7, 0.1, 0.7], type=pa.float64()) + assert result.equals(expected_descending) + + with pytest.raises(ValueError, match="not a valid sort order"): + pc.rank_quantile(arr, sort_keys="XXX") + + def create_sample_expressions(): # We need a schema for substrait conversion schema = pa.schema([pa.field("i64", pa.int64()), pa.field(