diff --git a/Cargo.lock b/Cargo.lock index 89d89ecfb053..3131e578e205 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4435,9 +4435,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.47.0" +version = "0.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25" +checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" dependencies = [ "log", ] diff --git a/Cargo.toml b/Cargo.toml index 2efb3e722d49..c10737d04b5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,7 +79,7 @@ simd-json = { version = "0.13", features = ["known-key"] } simdutf8 = "0.1.4" slotmap = "1" smartstring = "1" -sqlparser = "0.47" +sqlparser = "0.49" stacker = "0.1" streaming-iterator = "0.1.9" strength_reduce = "0.2" diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 3b2c3c43c919..9ff031a0a31c 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -9,9 +9,9 @@ use polars_ops::frame::JoinCoalesce; use polars_plan::dsl::function_expr::StructFunction; use polars_plan::prelude::*; use sqlparser::ast::{ - BinaryOperator, Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident, - JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, - RenameSelectItem, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, + BinaryOperator, CreateTable, Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, + GroupByExpr, Ident, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderBy, + Query, RenameSelectItem, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, Values, WildcardAdditionalOptions, }; @@ -382,10 +382,7 @@ impl SQLContext { let lf_schema = self.get_frame_schema(&mut lf)?; let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm)).collect(); let joined_tbl = match quantifier { - SetQuantifier::ByName | SetQuantifier::AllByName => { - // note: 'BY NAME' is pending https://github.com/sqlparser-rs/sqlparser-rs/pull/1309 - join.on(lf_cols).finish() - }, + SetQuantifier::ByName | SetQuantifier::AllByName => join.on(lf_cols).finish(), SetQuantifier::Distinct | SetQuantifier::None => { let rf_schema = self.get_frame_schema(&mut rf)?; let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| col(nm)).collect(); @@ -658,7 +655,10 @@ impl SQLContext { let mut group_by_keys: Vec = Vec::new(); match &select_stmt.group_by { // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values) - GroupByExpr::Expressions(group_by_exprs) => { + GroupByExpr::Expressions(group_by_exprs, modifiers) => { + if !modifiers.is_empty() { + polars_bail!(SQLInterface: "GROUP BY does not support CUBE, ROLLUP, or TOTALS modifiers") + } // translate the group expressions, allowing ordinal values group_by_keys = group_by_exprs .iter() @@ -675,7 +675,10 @@ impl SQLContext { }, // "GROUP BY ALL" syntax; automatically adds expressions that do not contain // nested agg/window funcs to the group key (also ignores literals). - GroupByExpr::All => { + GroupByExpr::All(modifiers) => { + if !modifiers.is_empty() { + polars_bail!(SQLInterface: "GROUP BY does not support CUBE, ROLLUP, or TOTALS modifiers") + } projections.iter().for_each(|expr| match expr { // immediately match the most common cases (col|agg|len|lit, optionally aliased). Expr::Agg(_) | Expr::Len | Expr::Literal(_) => (), @@ -704,7 +707,7 @@ impl SQLContext { lf = if group_by_keys.is_empty() { // Final/selected cols, accounting for 'SELECT *' modifiers let mut retained_cols = Vec::with_capacity(projections.len()); - let have_order_by = !query.order_by.is_empty(); + let have_order_by = query.order_by.is_some(); // Note: if there is an 'order by' then we project everything (original cols // and new projections) and *then* select the final cols; the retained cols @@ -736,9 +739,8 @@ impl SQLContext { if !select_modifiers.rename.is_empty() { lf = lf.with_columns(select_modifiers.renamed_cols()); } - if have_order_by { - lf = self.process_order_by(lf, &query.order_by, Some(&retained_cols))? - } + + lf = self.process_order_by(lf, &query.order_by, Some(&retained_cols))?; lf = lf.select(retained_cols); if !select_modifiers.rename.is_empty() { @@ -779,9 +781,7 @@ impl SQLContext { .collect::>>()?; // DISTINCT ON has to apply the ORDER BY before the operation. - if !query.order_by.is_empty() { - lf = self.process_order_by(lf, &query.order_by, None)?; - } + lf = self.process_order_by(lf, &query.order_by, None)?; return Ok(lf.unique_stable(Some(cols), UniqueKeepStrategy::First)); }, None => lf, @@ -903,12 +903,12 @@ impl SQLContext { } fn execute_create_table(&mut self, stmt: &Statement) -> PolarsResult { - if let Statement::CreateTable { + if let Statement::CreateTable(CreateTable { if_not_exists, name, query, .. - } = stmt + }) = stmt { let tbl_name = name.0.first().unwrap().value.as_str(); // CREATE TABLE IF NOT EXISTS @@ -976,6 +976,7 @@ impl SQLContext { array_exprs, with_offset, with_offset_alias: _, + .. } => { if let Some(alias) = alias { let table_name = alias.name.value.clone(); @@ -1021,8 +1022,8 @@ impl SQLContext { let lf = DataFrame::new(column_series)?.lazy(); if *with_offset { - // TODO: make a PR to `sqlparser-rs` to support 'ORDINALITY' - // (note that 'OFFSET' is BigQuery-specific syntax, not PostgreSQL) + // TODO: support 'WITH ORDINALITY' modifier. + // (note that 'WITH OFFSET' is BigQuery-specific syntax, not PostgreSQL) polars_bail!(SQLInterface: "UNNEST tables do not (yet) support WITH OFFSET/ORDINALITY"); } self.table_map.insert(table_name.clone(), lf.clone()); @@ -1058,12 +1059,16 @@ impl SQLContext { fn process_order_by( &mut self, mut lf: LazyFrame, - order_by: &[OrderByExpr], + order_by: &Option, selected: Option<&[Expr]>, ) -> PolarsResult { + if order_by.as_ref().map_or(true, |ob| ob.exprs.is_empty()) { + return Ok(lf); + } let schema = self.get_frame_schema(&mut lf)?; let columns_iter = schema.iter_names().map(|e| col(e)); + let order_by = order_by.as_ref().unwrap().exprs.clone(); let mut descending = Vec::with_capacity(order_by.len()); let mut nulls_last = Vec::with_capacity(order_by.len()); let mut by: Vec = Vec::with_capacity(order_by.len()); @@ -1262,9 +1267,6 @@ impl SQLContext { polars_bail!(SQLInterface: "EXCLUDE and EXCEPT wildcard options cannot be used together (prefer EXCLUDE)") } else if options.opt_exclude.is_some() && options.opt_ilike.is_some() { polars_bail!(SQLInterface: "EXCLUDE and ILIKE wildcard options cannot be used together") - } else if options.opt_rename.is_some() && options.opt_replace.is_some() { - // pending an upstream fix: https://github.com/sqlparser-rs/sqlparser-rs/pull/1321 - polars_bail!(SQLInterface: "RENAME and REPLACE wildcard options cannot (yet) be used together") } // SELECT * EXCLUDE diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index da6f24904d5d..ffcb3dbc4998 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -115,11 +115,13 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult DataType::UInt8, // see also: "custom" types below SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32, SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16, SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32, - SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) => DataType::UInt64, - SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below + SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => { + DataType::UInt64 + }, // --------------------------------- // float @@ -562,18 +564,34 @@ impl SQLExprVisitor<'_> { match left_dtype { DataType::Time if is_iso_time(s) => { - right.clone().strict_cast(left_dtype.clone()) + right.clone().str().to_time(StrptimeOptions { + strict: true, + ..Default::default() + }) }, DataType::Date if is_iso_date(s) => { - right.clone().strict_cast(left_dtype.clone()) + right.clone().str().to_date(StrptimeOptions { + strict: true, + ..Default::default() + }) }, - DataType::Datetime(_, _) if is_iso_datetime(s) || is_iso_date(s) => { + DataType::Datetime(tu, tz) if is_iso_datetime(s) || is_iso_date(s) => { if s.len() == 10 { // handle upcast from ISO date string (10 chars) to datetime - lit(format!("{}T00:00:00", s)).strict_cast(left_dtype.clone()) + lit(format!("{}T00:00:00", s)) } else { - lit(s.replacen(' ', "T", 1)).strict_cast(left_dtype.clone()) + lit(s.replacen(' ', "T", 1)) } + .str() + .to_datetime( + Some(*tu), + tz.clone(), + StrptimeOptions { + strict: true, + ..Default::default() + }, + lit("latest"), + ) }, _ => right.clone(), } @@ -834,13 +852,17 @@ impl SQLExprVisitor<'_> { (dtype_expr_match, self.active_schema.as_ref()) { if elems.dtype() == &DataType::String { - if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) = - schema.get(name) - { - elems = elems.strict_cast(&schema.get(name).unwrap().clone())?; + if let Some(dtype) = schema.get(name) { + if matches!( + dtype, + DataType::Date | DataType::Time | DataType::Datetime(_, _) + ) { + elems = elems.strict_cast(dtype)?; + } } } } + // if we are parsing the list as an element in a series, implode. // otherwise, return the series as-is. let res = if result_as_element { diff --git a/py-polars/tests/unit/sql/test_set_ops.py b/py-polars/tests/unit/sql/test_set_ops.py index 10b4bccea120..64508887d1c5 100644 --- a/py-polars/tests/unit/sql/test_set_ops.py +++ b/py-polars/tests/unit/sql/test_set_ops.py @@ -39,6 +39,36 @@ def test_except_intersect() -> None: assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res) +def test_except_intersect_by_name() -> None: + df1 = pl.DataFrame( # noqa: F841 + { + "x": [1, 9, 1, 1], + "y": [2, 3, 4, 4], + "z": [5, 5, 5, 5], + } + ) + df2 = pl.DataFrame( # noqa: F841 + { + "y": [2, None, 4], + "w": ["?", "!", "%"], + "z": [7, 6, 5], + "x": [1, 9, 1], + } + ) + res_e = pl.sql( + "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2", + eager=True, + ) + res_i = pl.sql( + "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2", + eager=True, + ) + assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)] + assert sorted(res_i.rows()) == [(1, 4, 5)] + assert res_e.columns == ["x", "y", "z"] + assert res_i.columns == ["x", "y", "z"] + + @pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"]) def test_except_intersect_errors(op: str) -> None: df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]}) # noqa: F841 diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index 90c1aa1384da..7cae56b24948 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -196,6 +196,9 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None: ("dtm > '2006-01-01'", [0, 1, 2]), # << implies '2006-01-01 00:00:00' ("dtm <= '2006-01-01'", []), # << implies '2006-01-01 00:00:00' ("dt != '1960-01-07'", [0, 1]), + ("tm != '22:10:30'", [0, 2]), + ("tm >= '11:00:00' AND tm < '22:00:00'", [0]), + ("tm BETWEEN '12:00:00' AND '23:59:58'", [0, 1]), ("dt BETWEEN '2050-01-01' AND '2100-12-31'", [1]), ("dt::datetime = '1960-01-07'", [2]), ("dt::datetime = '1960-01-07 00:00:00'", [2]), @@ -221,6 +224,11 @@ def test_implicit_temporal_strings(constraint: str, expected: list[int]) -> None date(2077, 1, 1), date(1960, 1, 7), ], + "tm": [ + time(17, 30, 45), + time(22, 10, 30), + time(10, 25, 15), + ], } ) res = df.sql(f"SELECT idx FROM self WHERE {constraint}") @@ -447,6 +455,6 @@ def test_timestamp_time_unit_errors() -> None: with pytest.raises( SQLInterfaceError, - match="sql parser error: Expected literal int, found: - ", + match="sql parser error: Expected: literal int, found: - ", ): ctx.execute("SELECT ts::timestamp(-3) FROM frame_data") diff --git a/py-polars/tests/unit/sql/test_wildcard_opts.py b/py-polars/tests/unit/sql/test_wildcard_opts.py index 7e704b5858cf..d1fd6873bc3c 100644 --- a/py-polars/tests/unit/sql/test_wildcard_opts.py +++ b/py-polars/tests/unit/sql/test_wildcard_opts.py @@ -131,6 +131,11 @@ def test_select_rename_exclude_sort(order_by: str, df: pl.DataFrame) -> None: ["ID"], [(333,), (222,), (111,)], ), + ( + "(ID // 3 AS ID) RENAME (ID AS Identifier)", + ["Identifier"], + [(333,), (222,), (111,)], + ), ( "((City || ':' || City) AS City, ID // -3 AS ID)", ["City", "ID"], @@ -151,10 +156,13 @@ def test_select_replace( for order_by in ("", "ORDER BY ID DESC", "ORDER BY -ID ASC"): res = df.sql(f"SELECT * REPLACE {replacements} FROM self {order_by}") if not order_by: - res = res.sort("ID", descending=True) + res = res.sort(check_cols[-1], descending=True) assert res.select(check_cols).rows() == expected - assert res.columns == df.columns + expected_columns = ( + check_cols + df.columns[1:] if check_cols == ["Identifier"] else df.columns + ) + assert res.columns == expected_columns def test_select_wildcard_errors(df: pl.DataFrame) -> None: