diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml
index 650bca823791a..af5dedf2d038c 100644
--- a/crates/polars-sql/Cargo.toml
+++ b/crates/polars-sql/Cargo.toml
@@ -17,7 +17,7 @@ polars-plan = { version = "0.31.1", path = "../polars-plan", features = ["compil
 serde = "1"
 serde_json = { version = "1" }
 # sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs.git", rev = "ae3b5844c839072c235965fe0d1bddc473dced87" }
-sqlparser = "0.34"
+sqlparser = "0.36.1"
 
 [features]
 csv = ["polars-lazy/csv"]
diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs
index 3c4c33d7030e0..46202289f18b7 100644
--- a/crates/polars-sql/src/functions.rs
+++ b/crates/polars-sql/src/functions.rs
@@ -181,6 +181,11 @@ pub(crate) enum PolarsSqlFunctions {
     /// SELECT column_2 from df WHERE ENDS_WITH(column_1, 'a');
     /// ```
     EndsWith,
+    /// SQL 'initcap' function (returns the string converted to title case)
+    /// ```sql
+    /// SELECT INITCAP(column_1) from df;
+    /// ```
+    InitCap,
     /// SQL 'left' function
     /// ```sql
     /// SELECT LEFT(column_1, 3) from df;
@@ -475,6 +480,7 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
             // String functions
             // ----
             "ends_with" => Self::EndsWith,
+            "initcap" => Self::InitCap,
             "length" => Self::Length,
             "left" => Self::Left,
             "lower" => Self::Lower,
@@ -578,6 +584,7 @@ impl SqlFunctionVisitor<'_> {
             // String functions
             // ----
             EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
+            InitCap => self.visit_unary(|e| e.str().to_titlecase()),
             Left => self.try_visit_binary(|e, length| {
                 Ok(e.str().str_slice(0, match length {
                     Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
diff --git a/py-polars/tests/unit/test_sql.py b/py-polars/tests/unit/test_sql.py
index 830e735f30d0f..46f61960b0e2f 100644
--- a/py-polars/tests/unit/test_sql.py
+++ b/py-polars/tests/unit/test_sql.py
@@ -709,6 +709,29 @@ def test_sql_round_ndigits_errors() -> None:
             ctx.execute("SELECT ROUND(n,-1) AS n FROM df")
 
 
+def test_sql_string_case() -> None:
+    df = pl.DataFrame({"words": ["Test SOME words"]})
+
+    with pl.SQLContext(frame=df) as ctx:
+        res = ctx.execute(
+            """
+            SELECT
+              words,
+              INITCAP(words) as cap,
+              UPPER(words) as upper,
+              LOWER(words) as lower,
+            FROM frame
+            """
+        ).collect()
+
+    assert res.to_dict(as_series=False) == {
+        "words": ["Test SOME words"],
+        "cap": ["Test Some Words"],
+        "upper": ["TEST SOME WORDS"],
+        "lower": ["test some words"],
+    }
+
+
 def test_sql_string_lengths() -> None:
     df = pl.DataFrame({"words": ["Café", None, "東京"]})
 