feat: Support "BY NAME" qualifier for SQL "INTERSECT" and "EXCEPT" set ops (#17835)
alexander-beedie authored Jul 25, 2024
1 parent 9625c82 commit 8373cdb
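
A minimal sketch of the new syntax (the frames and SQL mirror the new test added below; "BY NAME" matches columns by name rather than by position, so the two frames may list their columns in different orders):

    import polars as pl

    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"y": [2, None, 4], "w": ["?", "!", "%"], "z": [7, 6, 5], "x": [1, 9, 1]})

    # distinct rows of df1 that also appear in df2, compared on the shared column names
    res_i = pl.sql("SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2", eager=True)

    # distinct rows of df1 that do not appear in df2
    res_e = pl.sql("SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2", eager=True)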
Showing 7 changed files with 111 additions and 41 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -79,7 +79,7 @@ simd-json = { version = "0.13", features = ["known-key"] }
 simdutf8 = "0.1.4"
 slotmap = "1"
 smartstring = "1"
-sqlparser = "0.47"
+sqlparser = "0.49"
 stacker = "0.1"
 streaming-iterator = "0.1.9"
 strength_reduce = "0.2"
50 changes: 26 additions & 24 deletions crates/polars-sql/src/context.rs
@@ -9,9 +9,9 @@ use polars_ops::frame::JoinCoalesce;
 use polars_plan::dsl::function_expr::StructFunction;
 use polars_plan::prelude::*;
 use sqlparser::ast::{
-    BinaryOperator, Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, Ident,
-    JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query,
-    RenameSelectItem, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement,
+    BinaryOperator, CreateTable, Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg,
+    GroupByExpr, Ident, JoinConstraint, JoinOperator, ObjectName, ObjectType, Offset, OrderBy,
+    Query, RenameSelectItem, Select, SelectItem, SetExpr, SetOperator, SetQuantifier, Statement,
     TableAlias, TableFactor, TableWithJoins, UnaryOperator, Value as SQLValue, Values,
     WildcardAdditionalOptions,
 };
@@ -382,10 +382,7 @@ impl SQLContext {
         let lf_schema = self.get_frame_schema(&mut lf)?;
         let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm)).collect();
         let joined_tbl = match quantifier {
-            SetQuantifier::ByName | SetQuantifier::AllByName => {
-                // note: 'BY NAME' is pending https://github.com/sqlparser-rs/sqlparser-rs/pull/1309
-                join.on(lf_cols).finish()
-            },
+            SetQuantifier::ByName | SetQuantifier::AllByName => join.on(lf_cols).finish(),
             SetQuantifier::Distinct | SetQuantifier::None => {
                 let rf_schema = self.get_frame_schema(&mut rf)?;
                 let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| col(nm)).collect();
@@ -658,7 +655,10 @@ impl SQLContext {
         let mut group_by_keys: Vec<Expr> = Vec::new();
         match &select_stmt.group_by {
             // Standard "GROUP BY x, y, z" syntax (also recognising ordinal values)
-            GroupByExpr::Expressions(group_by_exprs) => {
+            GroupByExpr::Expressions(group_by_exprs, modifiers) => {
+                if !modifiers.is_empty() {
+                    polars_bail!(SQLInterface: "GROUP BY does not support CUBE, ROLLUP, or TOTALS modifiers")
+                }
                 // translate the group expressions, allowing ordinal values
                 group_by_keys = group_by_exprs
                     .iter()
@@ -675,7 +675,10 @@
             },
             // "GROUP BY ALL" syntax; automatically adds expressions that do not contain
             // nested agg/window funcs to the group key (also ignores literals).
-            GroupByExpr::All => {
+            GroupByExpr::All(modifiers) => {
+                if !modifiers.is_empty() {
+                    polars_bail!(SQLInterface: "GROUP BY does not support CUBE, ROLLUP, or TOTALS modifiers")
+                }
                 projections.iter().for_each(|expr| match expr {
                     // immediately match the most common cases (col|agg|len|lit, optionally aliased).
                     Expr::Agg(_) | Expr::Len | Expr::Literal(_) => (),
@@ -704,7 +707,7 @@ impl SQLContext {
         lf = if group_by_keys.is_empty() {
             // Final/selected cols, accounting for 'SELECT *' modifiers
             let mut retained_cols = Vec::with_capacity(projections.len());
-            let have_order_by = !query.order_by.is_empty();
+            let have_order_by = query.order_by.is_some();
 
             // Note: if there is an 'order by' then we project everything (original cols
             // and new projections) and *then* select the final cols; the retained cols
@@ -736,9 +739,8 @@ impl SQLContext {
             if !select_modifiers.rename.is_empty() {
                 lf = lf.with_columns(select_modifiers.renamed_cols());
             }
-            if have_order_by {
-                lf = self.process_order_by(lf, &query.order_by, Some(&retained_cols))?
-            }
+
+            lf = self.process_order_by(lf, &query.order_by, Some(&retained_cols))?;
             lf = lf.select(retained_cols);
 
             if !select_modifiers.rename.is_empty() {
@@ -779,9 +781,7 @@ impl SQLContext {
                     .collect::<PolarsResult<Vec<_>>>()?;
 
                 // DISTINCT ON has to apply the ORDER BY before the operation.
-                if !query.order_by.is_empty() {
-                    lf = self.process_order_by(lf, &query.order_by, None)?;
-                }
+                lf = self.process_order_by(lf, &query.order_by, None)?;
                 return Ok(lf.unique_stable(Some(cols), UniqueKeepStrategy::First));
             },
             None => lf,
@@ -903,12 +903,12 @@ impl SQLContext {
     }
 
     fn execute_create_table(&mut self, stmt: &Statement) -> PolarsResult<LazyFrame> {
-        if let Statement::CreateTable {
+        if let Statement::CreateTable(CreateTable {
             if_not_exists,
             name,
             query,
             ..
-        } = stmt
+        }) = stmt
         {
             let tbl_name = name.0.first().unwrap().value.as_str();
             // CREATE TABLE IF NOT EXISTS
@@ -976,6 +976,7 @@ impl SQLContext {
                 array_exprs,
                 with_offset,
                 with_offset_alias: _,
+                ..
             } => {
                 if let Some(alias) = alias {
                     let table_name = alias.name.value.clone();
@@ -1021,8 +1022,8 @@ impl SQLContext {
 
                     let lf = DataFrame::new(column_series)?.lazy();
                     if *with_offset {
-                        // TODO: make a PR to `sqlparser-rs` to support 'ORDINALITY'
-                        // (note that 'OFFSET' is BigQuery-specific syntax, not PostgreSQL)
+                        // TODO: support 'WITH ORDINALITY' modifier.
+                        // (note that 'WITH OFFSET' is BigQuery-specific syntax, not PostgreSQL)
                         polars_bail!(SQLInterface: "UNNEST tables do not (yet) support WITH OFFSET/ORDINALITY");
                     }
                     self.table_map.insert(table_name.clone(), lf.clone());
@@ -1058,12 +1059,16 @@ impl SQLContext {
     fn process_order_by(
         &mut self,
         mut lf: LazyFrame,
-        order_by: &[OrderByExpr],
+        order_by: &Option<OrderBy>,
         selected: Option<&[Expr]>,
     ) -> PolarsResult<LazyFrame> {
+        if order_by.as_ref().map_or(true, |ob| ob.exprs.is_empty()) {
+            return Ok(lf);
+        }
         let schema = self.get_frame_schema(&mut lf)?;
         let columns_iter = schema.iter_names().map(|e| col(e));
 
+        let order_by = order_by.as_ref().unwrap().exprs.clone();
         let mut descending = Vec::with_capacity(order_by.len());
         let mut nulls_last = Vec::with_capacity(order_by.len());
         let mut by: Vec<Expr> = Vec::with_capacity(order_by.len());
@@ -1262,9 +1267,6 @@ impl SQLContext {
             polars_bail!(SQLInterface: "EXCLUDE and EXCEPT wildcard options cannot be used together (prefer EXCLUDE)")
         } else if options.opt_exclude.is_some() && options.opt_ilike.is_some() {
             polars_bail!(SQLInterface: "EXCLUDE and ILIKE wildcard options cannot be used together")
-        } else if options.opt_rename.is_some() && options.opt_replace.is_some() {
-            // pending an upstream fix: https://github.com/sqlparser-rs/sqlparser-rs/pull/1321
-            polars_bail!(SQLInterface: "RENAME and REPLACE wildcard options cannot (yet) be used together")
         }
 
         // SELECT * EXCLUDE
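At the DataFrame level, the new `ByName` arm above joins on the left table's column names instead of pairing columns positionally. A rough Python equivalent — assuming, as the surrounding `join` builder suggests, that INTERSECT lowers to a semi-join and EXCEPT to an anti-join (the join type itself is configured outside this hunk):

    import polars as pl

    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})
    df2 = pl.DataFrame({"y": [2, None, 4], "w": ["?", "!", "%"], "z": [7, 6, 5], "x": [1, 9, 1]})

    # INTERSECT BY NAME: distinct df1 rows with a match in df2 on df1's column names
    intersect_by_name = df1.join(df2, on=df1.columns, how="semi").unique()

    # EXCEPT BY NAME: distinct df1 rows with no such match in df2
    except_by_name = df1.join(df2, on=df1.columns, how="anti").unique()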
44 changes: 33 additions & 11 deletions crates/polars-sql/src/sql_expr.rs
@@ -115,11 +115,13 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
     // unsigned integer: the following do not map to PostgreSQL types/syntax, but
     // are enabled for wider compatibility (eg: "CAST(col AS BIGINT UNSIGNED)").
     // ---------------------------------
-    SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below
     SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) => DataType::UInt32,
     SQLDataType::UnsignedInt2(_) | SQLDataType::UnsignedSmallInt(_) => DataType::UInt16,
     SQLDataType::UnsignedInt4(_) | SQLDataType::UnsignedMediumInt(_) => DataType::UInt32,
-    SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) => DataType::UInt64,
+    SQLDataType::UnsignedTinyInt(_) => DataType::UInt8, // see also: "custom" types below
+    SQLDataType::UnsignedInt8(_) | SQLDataType::UnsignedBigInt(_) | SQLDataType::UInt8 => {
+        DataType::UInt64
+    },
 
     // ---------------------------------
     // float
@@ -562,18 +564,34 @@ impl SQLExprVisitor<'_> {
 
         match left_dtype {
             DataType::Time if is_iso_time(s) => {
-                right.clone().strict_cast(left_dtype.clone())
+                right.clone().str().to_time(StrptimeOptions {
+                    strict: true,
+                    ..Default::default()
+                })
             },
             DataType::Date if is_iso_date(s) => {
-                right.clone().strict_cast(left_dtype.clone())
+                right.clone().str().to_date(StrptimeOptions {
+                    strict: true,
+                    ..Default::default()
+                })
             },
-            DataType::Datetime(_, _) if is_iso_datetime(s) || is_iso_date(s) => {
+            DataType::Datetime(tu, tz) if is_iso_datetime(s) || is_iso_date(s) => {
                 if s.len() == 10 {
                     // handle upcast from ISO date string (10 chars) to datetime
-                    lit(format!("{}T00:00:00", s)).strict_cast(left_dtype.clone())
+                    lit(format!("{}T00:00:00", s))
                 } else {
-                    lit(s.replacen(' ', "T", 1)).strict_cast(left_dtype.clone())
+                    lit(s.replacen(' ', "T", 1))
                 }
+                .str()
+                .to_datetime(
+                    Some(*tu),
+                    tz.clone(),
+                    StrptimeOptions {
+                        strict: true,
+                        ..Default::default()
+                    },
+                    lit("latest"),
+                )
             },
             _ => right.clone(),
         }
@@ -834,13 +852,17 @@ impl SQLExprVisitor<'_> {
             (dtype_expr_match, self.active_schema.as_ref())
         {
             if elems.dtype() == &DataType::String {
-                if let Some(DataType::Date | DataType::Time | DataType::Datetime(_, _)) =
-                    schema.get(name)
-                {
-                    elems = elems.strict_cast(&schema.get(name).unwrap().clone())?;
+                if let Some(dtype) = schema.get(name) {
+                    if matches!(
+                        dtype,
+                        DataType::Date | DataType::Time | DataType::Datetime(_, _)
+                    ) {
+                        elems = elems.strict_cast(dtype)?;
+                    }
                 }
             }
         }
+
         // if we are parsing the list as an element in a series, implode.
         // otherwise, return the series as-is.
        let res = if result_as_element {
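Note the behavioural shift above: string literals compared against temporal columns are no longer blanket `strict_cast`, but parsed as strict ISO strings via `str().to_time`/`to_date`/`to_datetime`, with the datetime path now honouring the column's time unit and time zone. A small sketch of the user-visible effect, using the data added to the tests below:

    from datetime import time

    import polars as pl

    df = pl.DataFrame(
        {
            "idx": [0, 1, 2],
            "tm": [time(17, 30, 45), time(22, 10, 30), time(10, 25, 15)],
        }
    )
    # the string literals are parsed as ISO times and compared against "tm";
    # only idx 0 (17:30:45) falls inside the half-open interval
    res = df.sql("SELECT idx FROM self WHERE tm >= '11:00:00' AND tm < '22:00:00'")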
30 changes: 30 additions & 0 deletions py-polars/tests/unit/sql/test_set_ops.py
@@ -39,6 +39,36 @@ def test_except_intersect() -> None:
     assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)
 
 
+def test_except_intersect_by_name() -> None:
+    df1 = pl.DataFrame(  # noqa: F841
+        {
+            "x": [1, 9, 1, 1],
+            "y": [2, 3, 4, 4],
+            "z": [5, 5, 5, 5],
+        }
+    )
+    df2 = pl.DataFrame(  # noqa: F841
+        {
+            "y": [2, None, 4],
+            "w": ["?", "!", "%"],
+            "z": [7, 6, 5],
+            "x": [1, 9, 1],
+        }
+    )
+    res_e = pl.sql(
+        "SELECT x, y, z FROM df1 EXCEPT BY NAME SELECT * FROM df2",
+        eager=True,
+    )
+    res_i = pl.sql(
+        "SELECT * FROM df1 INTERSECT BY NAME SELECT * FROM df2",
+        eager=True,
+    )
+    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
+    assert sorted(res_i.rows()) == [(1, 4, 5)]
+    assert res_e.columns == ["x", "y", "z"]
+    assert res_i.columns == ["x", "y", "z"]
+
+
 @pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
 def test_except_intersect_errors(op: str) -> None:
     df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
10 changes: 9 additions & 1 deletion py-polars/tests/unit/sql/test_temporal.py
@@ -196,6 +196,9 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None:
         ("dtm > '2006-01-01'", [0, 1, 2]),  # << implies '2006-01-01 00:00:00'
         ("dtm <= '2006-01-01'", []),  # << implies '2006-01-01 00:00:00'
         ("dt != '1960-01-07'", [0, 1]),
+        ("tm != '22:10:30'", [0, 2]),
+        ("tm >= '11:00:00' AND tm < '22:00:00'", [0]),
+        ("tm BETWEEN '12:00:00' AND '23:59:58'", [0, 1]),
         ("dt BETWEEN '2050-01-01' AND '2100-12-31'", [1]),
         ("dt::datetime = '1960-01-07'", [2]),
         ("dt::datetime = '1960-01-07 00:00:00'", [2]),
@@ -221,6 +224,11 @@ def test_implicit_temporal_strings(constraint: str, expected: list[int]) -> None
                 date(2077, 1, 1),
                 date(1960, 1, 7),
             ],
+            "tm": [
+                time(17, 30, 45),
+                time(22, 10, 30),
+                time(10, 25, 15),
+            ],
         }
     )
     res = df.sql(f"SELECT idx FROM self WHERE {constraint}")
@@ -447,6 +455,6 @@ def test_timestamp_time_unit_errors() -> None:
 
     with pytest.raises(
         SQLInterfaceError,
-        match="sql parser error: Expected literal int, found: - ",
+        match="sql parser error: Expected: literal int, found: - ",
     ):
         ctx.execute("SELECT ts::timestamp(-3) FROM frame_data")
12 changes: 10 additions & 2 deletions py-polars/tests/unit/sql/test_wildcard_opts.py
@@ -131,6 +131,11 @@ def test_select_rename_exclude_sort(order_by: str, df: pl.DataFrame) -> None:
             ["ID"],
             [(333,), (222,), (111,)],
         ),
+        (
+            "(ID // 3 AS ID) RENAME (ID AS Identifier)",
+            ["Identifier"],
+            [(333,), (222,), (111,)],
+        ),
         (
             "((City || ':' || City) AS City, ID // -3 AS ID)",
             ["City", "ID"],
@@ -151,10 +156,13 @@ def test_select_replace(
     for order_by in ("", "ORDER BY ID DESC", "ORDER BY -ID ASC"):
         res = df.sql(f"SELECT * REPLACE {replacements} FROM self {order_by}")
         if not order_by:
-            res = res.sort("ID", descending=True)
+            res = res.sort(check_cols[-1], descending=True)
 
         assert res.select(check_cols).rows() == expected
-        assert res.columns == df.columns
+        expected_columns = (
+            check_cols + df.columns[1:] if check_cols == ["Identifier"] else df.columns
+        )
+        assert res.columns == expected_columns
 
 
 def test_select_wildcard_errors(df: pl.DataFrame) -> None:
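With sqlparser 0.49, the RENAME and REPLACE wildcard options can now be combined (hence the guard removed from context.rs above), as exercised by the new parametrized case. A usage sketch — the ID values follow from the test's expected output, while the City column is hypothetical filler:

    import polars as pl

    df = pl.DataFrame({"ID": [999, 666, 333], "City": ["Tokyo", "Paris", "Lima"]})

    # REPLACE rewrites ID in place, then RENAME relabels the rewritten column
    res = df.sql("SELECT * REPLACE (ID // 3 AS ID) RENAME (ID AS Identifier) FROM self")
    assert sorted(res["Identifier"].to_list(), reverse=True) == [333, 222, 111]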
