Skip to content

Commit

Permalink
feat: Add SQL support for INTERSECT and EXCEPT ops (pola-rs#16960)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored and Wouittone committed Jun 22, 2024
1 parent 020c075 commit 32c8516
Show file tree
Hide file tree
Showing 18 changed files with 886 additions and 487 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ simd-json = { version = "0.13", features = ["known-key"] }
simdutf8 = "0.1.4"
slotmap = "1"
smartstring = "1"
sqlparser = "0.45"
sqlparser = "0.47"
stacker = "0.1"
streaming-iterator = "0.1.9"
strength_reduce = "0.2"
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ fmt: ## Run autoformatting and linting
pre-commit: fmt clippy clippy-default ## Run all code quality checks

.PHONY: clean
clean: ## Clean up caches and build artifacts
clean: ## Clean up caches, build artifacts, and the venv
@$(MAKE) -s -C py-polars/ $@
@rm -rf .ruff_cache/
@rm -rf .hypothesis/
Expand Down
140 changes: 110 additions & 30 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl SQLContext {
if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr {
Err(polars_err!(
SQLSyntax:
"negative ordinals values are invalid for {}; found -{}",
"negative ordinal values are invalid for {}; found -{}",
clause,
idx
))
Expand All @@ -225,7 +225,7 @@ impl SQLContext {
let idx = idx.parse::<usize>().map_err(|_| {
polars_err!(
SQLSyntax:
"negative ordinals values are invalid for {}; found {}",
"negative ordinal values are invalid for {}; found {}",
clause,
idx
)
Expand Down Expand Up @@ -273,50 +273,131 @@ impl SQLContext {
left,
right,
} => self.process_union(left, right, set_quantifier, query),
SetExpr::SetOperation { op, .. } => {
polars_bail!(SQLInterface: "'{}' operation not yet supported", op)
},

#[cfg(feature = "semi_anti_join")]
SetExpr::SetOperation {
op: SetOperator::Intersect | SetOperator::Except,
set_quantifier,
left,
right,
} => self.process_except_intersect(left, right, set_quantifier, query),

SetExpr::Values(Values {
explicit_row: _,
rows,
}) => self.process_values(rows),
op => polars_bail!(SQLInterface: "'{}' operation not yet supported", op),

SetExpr::Table(tbl) => {
if tbl.table_name.is_some() {
let table_name = tbl.table_name.as_ref().unwrap();
self.get_table_from_current_scope(table_name)
.ok_or_else(|| {
polars_err!(
SQLInterface: "no table or alias named '{}' found",
tbl
)
})
} else {
polars_bail!(SQLInterface: "'TABLE' requires valid table name")
}
},
op => {
let op = match op {
SetExpr::SetOperation { op, .. } => op,
_ => unreachable!(),
};
polars_bail!(SQLInterface: "'{}' operation is currently unsupported", op)
},
}
}

/// Process an `INTERSECT` or `EXCEPT` set operation by joining the two operand frames.
///
/// `EXCEPT` is executed as an anti join and everything else as a semi join
/// (reported as `INTERSECT`), in both cases with `join_nulls(true)` set; the
/// joined frame is then de-duplicated with `unique` before being returned.
///
/// How the join keys are chosen depends on `quantifier`:
/// - `BY NAME` / `ALL BY NAME`: the left frame's column names are used as the
///   join keys for both sides.
/// - none / `DISTINCT`: left and right columns are paired by position, so both
///   frames must have the same number of columns.
/// - anything else: rejected with an `SQLInterface` error.
///
/// # Errors
/// Returns an error if either operand fails to process, if a schema cannot be
/// resolved, if the column counts differ for a positional match, or if the
/// quantifier is not one of those listed above.
///
/// NOTE(review): the operator is read from `query.body` rather than being passed
/// in by the caller, and the same `query` is forwarded when processing `left` and
/// `right`. For nested set operations this looks like it would pick up the
/// outermost operator instead of the one being processed -- TODO confirm.
///
/// NOTE(review): `ALL BY NAME` shares the `BY NAME` arm and the result is always
/// passed through `unique`, so duplicates are dropped in that case too -- confirm
/// this is intended.
#[cfg(feature = "semi_anti_join")]
fn process_except_intersect(
&mut self,
left: &SetExpr,
right: &SetExpr,
quantifier: &SetQuantifier,
query: &Query,
) -> PolarsResult<LazyFrame> {
// Choose the join type (and the operator name used in error messages);
// any operator other than EXCEPT falls through to the INTERSECT case.
let (join_type, op_name) = match *query.body {
SetExpr::SetOperation {
op: SetOperator::Except,
..
} => (JoinType::Anti, "EXCEPT"),
_ => (JoinType::Semi, "INTERSECT"),
};
let mut lf = self.process_set_expr(left, query)?;
let mut rf = self.process_set_expr(right, query)?;
// The builder is given clones because `lf` and `rf` are still needed
// below to resolve their schemas.
let join = lf
.clone()
.join_builder()
.with(rf.clone())
.how(join_type)
.join_nulls(true);

// Column expressions for every left-hand column, in schema order.
let lf_schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm)).collect();
let joined_tbl = match quantifier {
SetQuantifier::ByName | SetQuantifier::AllByName => {
// note: 'BY NAME' is pending https://github.com/sqlparser-rs/sqlparser-rs/pull/1309
join.on(lf_cols).finish()
},
SetQuantifier::Distinct | SetQuantifier::None => {
// Positional match: pair the i-th left column with the i-th right
// column, which requires both sides to have the same width.
let rf_schema = rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| col(nm)).collect();
if lf_cols.len() != rf_cols.len() {
polars_bail!(SQLInterface: "{} requires equal number of columns in each table (use '{} BY NAME' to combine mismatched tables)", op_name, op_name)
}
join.left_on(lf_cols).right_on(rf_cols).finish()
},
_ => {
polars_bail!(SQLInterface: "'{} {}' is not supported", op_name, quantifier.to_string())
},
};
// Drop duplicate rows from the joined result.
Ok(joined_tbl.unique(None, UniqueKeepStrategy::Any))
}

fn process_union(
&mut self,
left: &SetExpr,
right: &SetExpr,
quantifier: &SetQuantifier,
query: &Query,
) -> PolarsResult<LazyFrame> {
let left = self.process_set_expr(left, query)?;
let right = self.process_set_expr(right, query)?;
let mut lf = self.process_set_expr(left, query)?;
let mut rf = self.process_set_expr(right, query)?;
let opts = UnionArgs {
parallel: true,
to_supertypes: true,
..Default::default()
};
match quantifier {
// UNION ALL
SetQuantifier::All => polars_lazy::dsl::concat(vec![left, right], opts),
// UNION [DISTINCT]
SetQuantifier::Distinct | SetQuantifier::None => {
let concatenated = polars_lazy::dsl::concat(vec![left, right], opts);
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
// UNION [ALL | DISTINCT]
SetQuantifier::All | SetQuantifier::Distinct | SetQuantifier::None => {
let lf_schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
let rf_schema = rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
if lf_schema.len() != rf_schema.len() {
polars_bail!(SQLInterface: "UNION requires equal number of columns in each table (use 'UNION BY NAME' to combine mismatched tables)")
}
let concatenated = polars_lazy::dsl::concat(vec![lf, rf], opts);
match quantifier {
SetQuantifier::Distinct | SetQuantifier::None => {
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
},
_ => concatenated,
}
},
// UNION ALL BY NAME
#[cfg(feature = "diagonal_concat")]
SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts),
SetQuantifier::AllByName => concat_lf_diagonal(vec![lf, rf], opts),
// UNION [DISTINCT] BY NAME
#[cfg(feature = "diagonal_concat")]
SetQuantifier::ByName | SetQuantifier::DistinctByName => {
let concatenated = concat_lf_diagonal(vec![left, right], opts);
let concatenated = concat_lf_diagonal(vec![lf, rf], opts);
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
},
#[allow(unreachable_patterns)]
_ => polars_bail!(SQLInterface: "'UNION {}' is not yet supported", quantifier),
_ => polars_bail!(SQLInterface: "'UNION {}' is not currently supported", quantifier),
}
}

Expand Down Expand Up @@ -466,7 +547,7 @@ impl SQLContext {
join_type => {
polars_bail!(
SQLInterface:
"join type '{:?}' not yet supported by polars-sql", join_type
"join type '{:?}' not currently supported", join_type
);
},
};
Expand Down Expand Up @@ -763,7 +844,7 @@ impl SQLContext {
let tbl_name = name.0.first().unwrap().value.as_str();
// CREATE TABLE IF NOT EXISTS
if *if_not_exists && self.table_map.contains_key(tbl_name) {
polars_bail!(SQLInterface: "relation {} already exists", tbl_name);
polars_bail!(SQLInterface: "relation '{}' already exists", tbl_name);
// CREATE OR REPLACE TABLE
}
if let Some(query) = query {
Expand All @@ -776,7 +857,7 @@ impl SQLContext {
.lazy();
Ok(out)
} else {
polars_bail!(SQLInterface: "only `CREATE TABLE AS SELECT` is currently supported");
polars_bail!(SQLInterface: "only `CREATE TABLE AS SELECT ...` is currently supported");
}
} else {
unreachable!()
Expand Down Expand Up @@ -881,8 +962,7 @@ impl SQLContext {
polars_bail!(SQLSyntax: "UNNEST table must have an alias");
}
},

// Support bare table, optional with alias for now
// Support bare table, optionally with an alias, for now
_ => polars_bail!(SQLInterface: "not yet implemented: {}", relation),
}
}
Expand All @@ -909,11 +989,11 @@ impl SQLContext {
fn process_order_by(
&mut self,
mut lf: LazyFrame,
ob: &[OrderByExpr],
order_by: &[OrderByExpr],
) -> PolarsResult<LazyFrame> {
let mut by = Vec::with_capacity(ob.len());
let mut descending = Vec::with_capacity(ob.len());
let mut nulls_last = Vec::with_capacity(ob.len());
let mut by = Vec::with_capacity(order_by.len());
let mut descending = Vec::with_capacity(order_by.len());
let mut nulls_last = Vec::with_capacity(order_by.len());

let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?);
let column_names = schema
Expand All @@ -923,7 +1003,7 @@ impl SQLContext {
.map(|e| col(e))
.collect::<Vec<_>>();

for ob in ob {
for ob in order_by {
// note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise
// https://www.postgresql.org/docs/current/queries-order.html
let desc_order = !ob.asc.unwrap_or(true);
Expand Down Expand Up @@ -951,7 +1031,7 @@ impl SQLContext {
) -> PolarsResult<LazyFrame> {
polars_ensure!(
!contains_wildcard,
SQLSyntax: "GROUP BY error: cannot process wildcard in group_by"
SQLSyntax: "GROUP BY error (cannot process wildcard in group_by)"
);
let schema_before = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?;
let group_by_keys_schema =
Expand Down Expand Up @@ -1082,7 +1162,7 @@ impl SQLContext {
cols(schema.iter_names())
},
e => polars_bail!(
SQLSyntax: "invalid wildcard expression: {:?}",
SQLSyntax: "invalid wildcard expression ({:?})",
e
),
};
Expand All @@ -1096,7 +1176,7 @@ impl SQLContext {
contains_wildcard_exclude: &mut bool,
) -> PolarsResult<Expr> {
if options.opt_except.is_some() {
polars_bail!(SQLSyntax: "EXCEPT not supported; use EXCLUDE instead")
polars_bail!(SQLSyntax: "EXCEPT not supported (use EXCLUDE instead)")
}
Ok(match &options.opt_exclude {
Some(ExcludeSelectItem::Single(ident)) => {
Expand Down
Loading

0 comments on commit 32c8516

Please sign in to comment.