Skip to content

Commit

Permalink
fix: SQL COUNT(DISTINCT x) should not include NULL values
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jul 29, 2024
1 parent 9c29683 commit 7a7f455
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 9 deletions.
4 changes: 3 additions & 1 deletion crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Sub;

use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
use polars_core::export::regex;
use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit};
Expand Down Expand Up @@ -1573,7 +1575,7 @@ impl SQLFunctionVisitor<'_> {
(true, [FunctionArgExpr::Expr(sql_expr)]) => {
let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?;
let expr = self.apply_window_spec(expr, &self.func.over)?;
Ok(expr.n_unique())
Ok(expr.clone().n_unique().sub(expr.null_count().gt(lit(0))))
},
_ => self.not_supported_error(),
}
Expand Down
18 changes: 10 additions & 8 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3154,22 +3154,24 @@ def n_unique(self) -> Expr:
Notes
-----
`null` is considered to be a unique value for the purposes of this operation.
To get the unique count excluding nulls, use `col.drop_nulls().n_unique()`.
Examples
--------
>>> df = pl.DataFrame({"x": [1, 1, 2, 2, 3], "y": [1, 1, 1, None, None]})
>>> df.select(
... x_unique=pl.col("x").n_unique(),
... y_unique=pl.col("y").n_unique(),
... y_unique_ex_null=pl.col("y").drop_nulls().n_unique(),
... )
shape: (1, 2)
┌──────────┬──────────┐
│ x_unique ┆ y_unique │
│ --- ┆ --- │
│ u32 ┆ u32 │
╞══════════╪══════════╡
│ 3 ┆ 2 │
└──────────┴──────────┘
shape: (1, 3)
┌──────────┬──────────┬──────────────────┐
│ x_unique ┆ y_unique ┆ y_unique_ex_null │
│ ---      ┆ ---      ┆ ---              │
│ u32      ┆ u32      ┆ u32              │
╞══════════╪══════════╪══════════════════╡
│ 3        ┆ 2        ┆ 1                │
└──────────┴──────────┴──────────────────┘
"""
return self._from_pyexpr(self._pyexpr.n_unique())

Expand Down
54 changes: 54 additions & 0 deletions py-polars/tests/unit/sql/test_miscellaneous.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,60 @@ def test_any_all() -> None:
}


def test_count() -> None:
    """Check SQL COUNT / COUNT(DISTINCT ...) results, with and without nulls.

    Two frames are queried through ``DataFrame.sql``:

    * a mixed frame where column ``c`` contains two nulls, and
    * a frame whose single column ``x`` is entirely null.

    The expected values pin that ``COUNT(col)`` and ``COUNT(DISTINCT col)``
    leave nulls out of the tally, while ``COUNT(*)`` counts every row.
    """
    # --- mixed values: "c" has two nulls among five rows -------------------
    mixed_query = """
SELECT
-- count
COUNT(a) AS count_a,
COUNT(b) AS count_b,
COUNT(c) AS count_c,
COUNT(*) AS count_star,
COUNT(NULL) AS count_null,
-- count distinct
COUNT(DISTINCT a) AS count_unique_a,
COUNT(DISTINCT b) AS count_unique_b,
COUNT(DISTINCT c) AS count_unique_c,
COUNT(DISTINCT NULL) AS count_unique_null,
FROM self
"""
    mixed_frame = pl.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": [1, 1, 22, 22, 333],
            "c": [1, 1, None, None, 2],
        }
    )
    expected_mixed = {
        "count_a": [5],
        "count_b": [5],
        "count_c": [3],
        "count_star": [5],
        "count_null": [0],
        "count_unique_a": [5],
        "count_unique_b": [3],
        "count_unique_c": [2],
        "count_unique_null": [0],
    }
    mixed_counts = mixed_frame.sql(mixed_query).to_dict(as_series=False)
    assert mixed_counts == expected_mixed

    # --- all-null column: only COUNT(*) should be non-zero -----------------
    null_query = """
SELECT
COUNT(x) AS count_x,
COUNT(*) AS count_star,
COUNT(DISTINCT x) AS count_unique_x
FROM self
"""
    null_frame = pl.DataFrame({"x": [None, None, None]})
    expected_null = {
        "count_x": [0],
        "count_star": [3],
        "count_unique_x": [0],
    }
    null_counts = null_frame.sql(null_query).to_dict(as_series=False)
    assert null_counts == expected_null


def test_distinct() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 7a7f455

Please sign in to comment.