Skip to content

Commit

Permalink
fix(rust, python)!: return f64 for rank when method="average" (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Aug 26, 2023
1 parent 9ab5cdd commit 0e507dd
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 26 deletions.
22 changes: 11 additions & 11 deletions crates/polars-core/src/chunked_array/ops/unique/rank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Optio
match s.len() {
1 => {
return match method {
Average => Series::new(s.name(), &[1.0f32]),
Average => Series::new(s.name(), &[1.0f64]),
_ => Series::new(s.name(), &[1 as IdxSize]),
};
},
0 => {
return match method {
Average => Float32Chunked::from_slice(s.name(), &[]).into_series(),
Average => Float64Chunked::from_slice(s.name(), &[]).into_series(),
_ => IdxCa::from_slice(s.name(), &[]).into_series(),
};
},
Expand Down Expand Up @@ -289,10 +289,10 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Optio
Average => {
// SAFETY: in bounds.
let a = unsafe { count.take_unchecked((&dense).into()) }
.cast(&DataType::Float32)
.cast(&DataType::Float64)
.unwrap();
let b = unsafe { count.take_unchecked((&(dense - 1)).into()) }
.cast(&DataType::Float32)
.cast(&DataType::Float64)
.unwrap()
+ 1.0;
(&a + &b) * 0.5
Expand Down Expand Up @@ -354,25 +354,25 @@ mod test {
assert_eq!(out, &[2, 3, 6, 3, 3, 6, 1]);

let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.f64()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2.0f32, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]);
assert_eq!(out, &[2.0f64, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]);

let s = Series::new(
"a",
&[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)],
);

let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.f64()?
.into_iter()
.collect::<Vec<_>>();

assert_eq!(
out,
&[
Some(2.0f32),
Some(2.0f64),
Some(3.5),
Some(5.0),
Some(3.5),
Expand Down Expand Up @@ -419,10 +419,10 @@ mod test {
fn test_rank_all_null() -> PolarsResult<()> {
let s = UInt32Chunked::new("", &[None, None, None]).into_series();
let out = rank(&s, RankMethod::Average, false, None)
.f32()?
.f64()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2.0f32, 2.0, 2.0]);
assert_eq!(out, &[2.0f64, 2.0, 2.0]);
let out = rank(&s, RankMethod::Dense, false, None)
.idx()?
.into_no_null_iter()
Expand All @@ -435,7 +435,7 @@ mod test {
fn test_rank_empty() {
let s = UInt32Chunked::from_slice("", &[]).into_series();
let out = rank(&s, RankMethod::Average, false, None);
assert_eq!(out.dtype(), &DataType::Float32);
assert_eq!(out.dtype(), &DataType::Float64);
let out = rank(&s, RankMethod::Max, false, None);
assert_eq!(out.dtype(), &IDX_DTYPE);
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,7 @@ fn test_single_ranked_group() -> PolarsResult<()> {
.collect()?;

let out = out.column("value")?.explode()?;
let out = out.f32()?;
let out = out.f64()?;
assert_eq!(
Vec::from(out),
&[Some(1.0), Some(2.0), Some(1.0), Some(2.0), Some(1.0)]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ impl Expr {
self.apply(
move |s| Ok(Some(s.rank(options, seed))),
GetOutput::map_field(move |fld| match options.method {
RankMethod::Average => Field::new(fld.name(), DataType::Float32),
RankMethod::Average => Field::new(fld.name(), DataType::Float64),
_ => Field::new(fld.name(), IDX_DTYPE),
}),
)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7061,7 +7061,7 @@ def rank(
┌─────┐
│ a │
│ --- │
│ f32 │
│ f64 │
╞═════╡
│ 3.0 │
│ 4.5 │
Expand Down Expand Up @@ -7095,7 +7095,7 @@ def rank(
┌─────┬─────┬──────┐
│ a ┆ b ┆ rank │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f32 │
│ i64 ┆ i64 ┆ f64 │
╞═════╪═════╪══════╡
│ 1 ┆ 6 ┆ 1.0 │
│ 1 ┆ 7 ┆ 2.0 │
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def eval(self, expr: Expr, *, parallel: bool = False) -> Expr:
┌─────┬─────┬────────────┐
│ a ┆ b ┆ rank │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ list[f32] │
│ i64 ┆ i64 ┆ list[f64] │
╞═════╪═════╪════════════╡
│ 1 ┆ 4 ┆ [1.0, 2.0] │
│ 8 ┆ 5 ┆ [2.0, 1.0] │
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def element() -> Expr:
┌─────┬─────┬────────────┐
│ a ┆ b ┆ rank │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ list[f32] │
│ i64 ┆ i64 ┆ list[f64] │
╞═════╪═════╪════════════╡
│ 1 ┆ 4 ┆ [1.0, 2.0] │
│ 8 ┆ 5 ┆ [2.0, 1.0] │
Expand Down Expand Up @@ -907,7 +907,7 @@ def corr(
┌─────┐
│ a │
│ --- │
│ f32 │
│ f64 │
╞═════╡
│ 0.5 │
└─────┘
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def eval(self, expr: Expr, *, parallel: bool = False) -> Series:
┌─────┬─────┬────────────┐
│ a ┆ b ┆ rank │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ list[f32] │
│ i64 ┆ i64 ┆ list[f64] │
╞═════╪═════╪════════════╡
│ 1 ┆ 4 ┆ [1.0, 2.0] │
│ 8 ┆ 5 ┆ [2.0, 1.0] │
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5773,7 +5773,7 @@ def rank(
>>> s = pl.Series("a", [3, 6, 1, 1, 6])
>>> s.rank()
shape: (5,)
Series: 'a' [f32]
Series: 'a' [f64]
[
3.0
4.5
Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ def test_list_eval_dtype_inference() -> None:
.list.first()
]
).to_series().to_list() == [
0.3333333432674408,
0.6666666865348816,
0.6666666865348816,
0.3333333432674408,
0.3333333333333333,
0.6666666666666666,
0.6666666666666666,
0.3333333333333333,
]


Expand Down
3 changes: 3 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,9 @@ def test_rank() -> None:
pl.Series("a", [3, 2, 1, 2, 2, 1, 4], dtype=UInt32),
)

assert s.rank(method="average").dtype == pl.Float64
assert s.rank(method="max").dtype == pl.get_index_type()


def test_diff() -> None:
s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,22 @@ def test_arr_contains() -> None:
}


def test_rank() -> None:
    """Check the values and dtype returned by ``rank`` on an expression.

    The ``"average"`` method is expected to yield a ``Float64`` column and
    the ``"max"`` method a column of the index dtype.
    """
    frame = pl.DataFrame({"a": [1, 1, 2, 2, 3]})

    # Average ranking: the expected output gives tied values the mean of
    # the positions they occupy.
    average_expr = pl.col("a").rank(method="average").alias("b")
    average_ranks = frame.select(average_expr).to_series()
    assert average_ranks.to_list() == [1.5, 1.5, 3.5, 3.5, 5.0]
    assert average_ranks.dtype == pl.Float64

    # Max ranking: the expected output gives tied values the highest
    # position in their group.
    max_expr = pl.col("a").rank(method="max").alias("b")
    max_ranks = frame.select(max_expr).to_series()
    assert max_ranks.to_list() == [2, 2, 4, 4, 5]
    assert max_ranks.dtype == pl.get_index_type()


def test_rank_so_4109() -> None:
# also tests ranks null behavior
df = pl.from_dict(
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,9 +1016,9 @@ def test_spearman_corr_ties() -> None:
)
expected = pl.DataFrame(
[
pl.Series("a1", [-0.19048483669757843], dtype=pl.Float32),
pl.Series("a1", [-0.19048482943986483], dtype=pl.Float64),
pl.Series("a2", [-0.17223653586587362], dtype=pl.Float64),
pl.Series("a3", [-0.19048483669757843], dtype=pl.Float32),
pl.Series("a3", [-0.19048482943986483], dtype=pl.Float64),
]
)
assert_frame_equal(result, expected)
Expand Down

0 comments on commit 0e507dd

Please sign in to comment.