Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-45380: [Python] Expose RankQuantileOptions to Python #45392

Merged
merged 7 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankQuantileOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
Expand Down
13 changes: 3 additions & 10 deletions python/pyarrow/_acero.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table,
from pyarrow.lib import frombytes, tobytes
from pyarrow._compute cimport (
Expression, FunctionOptions, _ensure_field_ref, _true,
unwrap_null_placement, unwrap_sort_order
unwrap_null_placement, unwrap_sort_keys
)


Expand Down Expand Up @@ -234,17 +234,10 @@ class AggregateNodeOptions(_AggregateNodeOptions):
cdef class _OrderByNodeOptions(ExecNodeOptions):

def _set_options(self, sort_keys, null_placement):
cdef:
vector[CSortKey] c_sort_keys

for name, order in sort_keys:
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
)

self.wrapped.reset(
new COrderByNodeOptions(
COrdering(c_sort_keys, unwrap_null_placement(null_placement))
COrdering(unwrap_sort_keys(sort_keys, allow_str=False),
unwrap_null_placement(null_placement))
)
)

Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/_compute.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ cdef CExpression _true

cdef CFieldRef _ensure_field_ref(value) except *

cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=*) except *

cdef CSortOrder unwrap_sort_order(order) except *

cdef CNullPlacement unwrap_null_placement(null_placement) except *
72 changes: 49 additions & 23 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ def _forbid_instantiation(klass, subclasses_instead=True):
raise TypeError(msg)


cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True):
cdef vector[CSortKey] c_sort_keys
if allow_str and isinstance(sort_keys, str):
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys))
)
else:
for name, order in sort_keys:
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
)
return c_sort_keys


cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
"""
Wrap a C++ scalar Function in a ScalarFunction object.
Expand Down Expand Up @@ -2093,13 +2107,9 @@ class ArraySortOptions(_ArraySortOptions):

cdef class _SortOptions(FunctionOptions):
def _set_options(self, sort_keys, null_placement):
cdef vector[CSortKey] c_sort_keys
for name, order in sort_keys:
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
)
self.wrapped.reset(new CSortOptions(
c_sort_keys, unwrap_null_placement(null_placement)))
unwrap_sort_keys(sort_keys, allow_str=False),
unwrap_null_placement(null_placement)))


class SortOptions(_SortOptions):
Expand All @@ -2125,12 +2135,7 @@ class SortOptions(_SortOptions):

cdef class _SelectKOptions(FunctionOptions):
def _set_options(self, k, sort_keys):
cdef vector[CSortKey] c_sort_keys
for name, order in sort_keys:
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
)
self.wrapped.reset(new CSelectKOptions(k, c_sort_keys))
self.wrapped.reset(new CSelectKOptions(k, unwrap_sort_keys(sort_keys, allow_str=False)))


class SelectKOptions(_SelectKOptions):
Expand Down Expand Up @@ -2317,19 +2322,9 @@ cdef class _RankOptions(FunctionOptions):
}

def _set_options(self, sort_keys, null_placement, tiebreaker):
cdef vector[CSortKey] c_sort_keys
if isinstance(sort_keys, str):
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys))
)
else:
for name, order in sort_keys:
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
)
try:
self.wrapped.reset(
new CRankOptions(c_sort_keys,
new CRankOptions(unwrap_sort_keys(sort_keys),
unwrap_null_placement(null_placement),
self._tiebreaker_map[tiebreaker])
)
Expand Down Expand Up @@ -2370,6 +2365,37 @@ class RankOptions(_RankOptions):
self._set_options(sort_keys, null_placement, tiebreaker)


cdef class _RankQuantileOptions(FunctionOptions):

def _set_options(self, sort_keys, null_placement):
self.wrapped.reset(
new CRankQuantileOptions(unwrap_sort_keys(sort_keys),
unwrap_null_placement(null_placement))
)


class RankQuantileOptions(_RankQuantileOptions):
"""
Options for the `rank_quantile` function.

Parameters
----------
sort_keys : sequence of (name, order) tuples or str, default "ascending"
Names of field/column keys to sort the input on,
along with the order each field/column is sorted in.
Accepted values for `order` are "ascending", "descending".
The field name can be a string column name or expression.
Alternatively, one can simply pass "ascending" or "descending" as a string
if the input is array-like.
null_placement : str, default "at_end"
Where nulls in input should be sorted.
Accepted values are "at_start", "at_end".
"""

def __init__(self, sort_keys="ascending", *, null_placement="at_end"):
self._set_options(sort_keys, null_placement)


cdef class Expression(_Weakrefable):
"""
A logical expression to be evaluated against some input.
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
QuantileOptions,
RandomOptions,
RankOptions,
RankQuantileOptions,
ReplaceSliceOptions,
ReplaceSubstringOptions,
RoundBinaryOptions,
Expand Down
6 changes: 6 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2788,6 +2788,12 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CNullPlacement null_placement
CRankOptionsTiebreaker tiebreaker

cdef cppclass CRankQuantileOptions \
"arrow::compute::RankQuantileOptions"(CFunctionOptions):
CRankQuantileOptions(vector[CSortKey] sort_keys, CNullPlacement)
vector[CSortKey] sort_keys
CNullPlacement null_placement

cdef enum DatumType" arrow::Datum::type":
DatumType_NONE" arrow::Datum::NONE"
DatumType_SCALAR" arrow::Datum::SCALAR"
Expand Down
32 changes: 32 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def test_option_class_equality(request):
pc.RandomOptions(),
pc.RankOptions(sort_keys="ascending",
null_placement="at_start", tiebreaker="max"),
pc.RankQuantileOptions(sort_keys="ascending",
null_placement="at_start"),
pc.ReplaceSliceOptions(0, 1, "a"),
pc.ReplaceSubstringOptions("a", "b"),
pc.RoundOptions(2, "towards_infinity"),
Expand Down Expand Up @@ -3360,6 +3362,36 @@ def test_rank_options():
tiebreaker="NonExisting")


def test_rank_quantile_options():
arr = pa.array([None, 1, None, 2, None])
expected = pa.array([0.7, 0.1, 0.7, 0.3, 0.7], type=pa.float64())

# Ensure rank_quantile can be called without specifying options
result = pc.rank_quantile(arr)
assert result.equals(expected)

# Ensure default RankOptions
result = pc.rank_quantile(arr, options=pc.RankQuantileOptions())
assert result.equals(expected)

# Ensure sort_keys tuple usage
result = pc.rank_quantile(arr, options=pc.RankQuantileOptions(
sort_keys=[("b", "ascending")])
)
assert result.equals(expected)

result = pc.rank_quantile(arr, null_placement="at_start")
expected_at_start = pa.array([0.3, 0.7, 0.3, 0.9, 0.3], type=pa.float64())
assert result.equals(expected_at_start)

result = pc.rank_quantile(arr, sort_keys="descending")
expected_descending = pa.array([0.7, 0.3, 0.7, 0.1, 0.7], type=pa.float64())
assert result.equals(expected_descending)

with pytest.raises(ValueError, match="not a valid sort order"):
pc.rank_quantile(arr, sort_keys="XXX")


def create_sample_expressions():
# We need a schema for substrait conversion
schema = pa.schema([pa.field("i64", pa.int64()), pa.field(
Expand Down
Loading