From 7a7f4552901d88d9a7362774c004c11de8a60fd4 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 29 Jul 2024 23:01:35 +0400 Subject: [PATCH] fix: SQL `COUNT(DISTINCT x)` should not include NULL values --- crates/polars-sql/src/functions.rs | 4 +- py-polars/polars/expr/expr.py | 18 ++++--- .../tests/unit/sql/test_miscellaneous.py | 54 +++++++++++++++++++ 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 80af7e15f616..0124a5409f7d 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1,3 +1,5 @@ +use std::ops::Sub; + use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::export::regex; use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit}; @@ -1573,7 +1575,7 @@ impl SQLFunctionVisitor<'_> { (true, [FunctionArgExpr::Expr(sql_expr)]) => { let expr = parse_sql_expr(sql_expr, self.ctx, self.active_schema)?; let expr = self.apply_window_spec(expr, &self.func.over)?; - Ok(expr.n_unique()) + Ok(expr.clone().n_unique().sub(expr.null_count().gt(lit(0)))) }, _ => self.not_supported_error(), } diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index ec65f40af0d7..cbcbf1b65d43 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -3154,6 +3154,7 @@ def n_unique(self) -> Expr: Notes ----- `null` is considered to be a unique value for the purposes of this operation. + To get the unuique count excluding nulls, use `col.drop_nulls().n_unique()`. Examples -------- @@ -3161,15 +3162,16 @@ def n_unique(self) -> Expr: >>> df.select( ... x_unique=pl.col("x").n_unique(), ... y_unique=pl.col("y").n_unique(), + ... y_unique_ex_null=pl.col("y").drop_nulls().n_unique(), ... ) - shape: (1, 2) - ┌──────────┬──────────┐ - │ x_unique ┆ y_unique │ - │ --- ┆ --- │ - │ u32 ┆ u32 │ - ╞══════════╪══════════╡ - │ 3 ┆ 2 │ - └──────────┴──────────┘ + shape: (1, 3) + ┌──────────┬──────────┬──────────────────┐ + │ x_unique ┆ y_unique ┆ y_unique_ex_null │ + │ --- ┆ --- ┆ --- │ + │ u32 ┆ u32 ┆ u32 │ + ╞══════════╪══════════╪══════════════════╡ + │ 3 ┆ 2 ┆ 1 │ + └──────────┴──────────┴──────────────────┘ """ return self._from_pyexpr(self._pyexpr.n_unique()) diff --git a/py-polars/tests/unit/sql/test_miscellaneous.py b/py-polars/tests/unit/sql/test_miscellaneous.py index f3979219c813..1ecd08e01e25 100644 --- a/py-polars/tests/unit/sql/test_miscellaneous.py +++ b/py-polars/tests/unit/sql/test_miscellaneous.py @@ -53,6 +53,60 @@ def test_any_all() -> None: } +def test_count() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [1, 1, 22, 22, 333], + "c": [1, 1, None, None, 2], + } + ) + res = df.sql( + """ + SELECT + -- count + COUNT(a) AS count_a, + COUNT(b) AS count_b, + COUNT(c) AS count_c, + COUNT(*) AS count_star, + COUNT(NULL) AS count_null, + -- count distinct + COUNT(DISTINCT a) AS count_unique_a, + COUNT(DISTINCT b) AS count_unique_b, + COUNT(DISTINCT c) AS count_unique_c, + COUNT(DISTINCT NULL) AS count_unique_null, + FROM self + """, + ) + assert res.to_dict(as_series=False) == { + "count_a": [5], + "count_b": [5], + "count_c": [3], + "count_star": [5], + "count_null": [0], + "count_unique_a": [5], + "count_unique_b": [3], + "count_unique_c": [2], + "count_unique_null": [0], + } + + df = pl.DataFrame({"x": [None, None, None]}) + res = df.sql( + """ + SELECT + COUNT(x) AS count_x, + COUNT(*) AS count_star, + COUNT(DISTINCT x) AS count_unique_x + FROM self + """ + ) + assert res.to_dict(as_series=False) == { + "count_x": [0], + "count_star": [3], + "count_unique_x": [0], + } + + def test_distinct() -> None: df = pl.DataFrame( {