Skip to content

Commit

Permalink
feat(python): add INITCAP string function for SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Aug 9, 2023
1 parent ae38e39 commit baebab2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ polars-plan = { version = "0.31.1", path = "../polars-plan", features = ["compil
serde = "1"
serde_json = { version = "1" }
# sqlparser = { git = "https://github.com/sqlparser-rs/sqlparser-rs.git", rev = "ae3b5844c839072c235965fe0d1bddc473dced87" }
sqlparser = "0.34"
sqlparser = "0.36.1"

[features]
csv = ["polars-lazy/csv"]
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT column_2 from df WHERE ENDS_WITH(column_1, 'a');
/// ```
EndsWith,
/// SQL 'initcap' function
/// ```sql
/// SELECT INITCAP(column_1) from df;
/// ```
InitCap,
/// SQL 'left' function
/// ```sql
/// SELECT LEFT(column_1, 3) from df;
Expand Down Expand Up @@ -475,6 +480,7 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
// String functions
// ----
"ends_with" => Self::EndsWith,
"initcap" => Self::InitCap,
"length" => Self::Length,
"left" => Self::Left,
"lower" => Self::Lower,
Expand Down Expand Up @@ -578,6 +584,7 @@ impl SqlFunctionVisitor<'_> {
// String functions
// ----
EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)),
InitCap => self.visit_unary(|e| e.str().to_titlecase()),
Left => self.try_visit_binary(|e, length| {
Ok(e.str().str_slice(0, match length {
Expr::Literal(LiteralValue::Int64(n)) => Some(n as u64),
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,29 @@ def test_sql_round_ndigits_errors() -> None:
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_sql_string_case() -> None:
df = pl.DataFrame({"words": ["Test SOME words"]})

with pl.SQLContext(frame=df) as ctx:
res = ctx.execute(
"""
SELECT
words,
INITCAP(words) as cap,
UPPER(words) as upper,
LOWER(words) as lower,
FROM frame
"""
).collect()

assert res.to_dict(False) == {
"words": ["Test SOME words"],
"cap": ["Test Some Words"],
"upper": ["TEST SOME WORDS"],
"lower": ["test some words"],
}


def test_sql_string_lengths() -> None:
df = pl.DataFrame({"words": ["Café", None, "東京"]})

Expand Down

0 comments on commit baebab2

Please sign in to comment.