diff --git a/Cargo.lock b/Cargo.lock index aedf596f0188..d59d832111cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4413,9 +4413,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.45.0" +version = "0.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0" +checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25" dependencies = [ "log", ] diff --git a/Cargo.toml b/Cargo.toml index 98550ca22299..5b50b6e5094d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -78,7 +78,7 @@ simd-json = { version = "0.13", features = ["known-key"] } simdutf8 = "0.1.4" slotmap = "1" smartstring = "1" -sqlparser = "0.45" +sqlparser = "0.47" stacker = "0.1" streaming-iterator = "0.1.9" strength_reduce = "0.2" diff --git a/Makefile b/Makefile index f38958b621c0..2cd38c395a8d 100644 --- a/Makefile +++ b/Makefile @@ -109,7 +109,7 @@ fmt: ## Run autoformatting and linting pre-commit: fmt clippy clippy-default ## Run all code quality checks .PHONY: clean -clean: ## Clean up caches and build artifacts +clean: ## Clean up caches, build artifacts, and the venv @$(MAKE) -s -C py-polars/ $@ @rm -rf .ruff_cache/ @rm -rf .hypothesis/ diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 1570b6e0c919..936f743929c0 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -212,7 +212,7 @@ impl SQLContext { if let SQLExpr::Value(SQLValue::Number(ref idx, _)) = **expr { Err(polars_err!( SQLSyntax: - "negative ordinals values are invalid for {}; found -{}", + "negative ordinal values are invalid for {}; found -{}", clause, idx )) @@ -225,7 +225,7 @@ impl SQLContext { let idx = idx.parse::().map_err(|_| { polars_err!( SQLSyntax: - "negative ordinals values are invalid for {}; found {}", + "negative ordinal values are invalid for {}; found {}", clause, idx ) @@ -273,17 +273,90 @@ impl SQLContext { left, right, } => self.process_union(left, right, set_quantifier, query), - SetExpr::SetOperation { op, .. } => { - polars_bail!(SQLInterface: "'{}' operation not yet supported", op) - }, + + #[cfg(feature = "semi_anti_join")] + SetExpr::SetOperation { + op: SetOperator::Intersect | SetOperator::Except, + set_quantifier, + left, + right, + } => self.process_except_intersect(left, right, set_quantifier, query), + SetExpr::Values(Values { explicit_row: _, rows, }) => self.process_values(rows), - op => polars_bail!(SQLInterface: "'{}' operation not yet supported", op), + + SetExpr::Table(tbl) => { + if tbl.table_name.is_some() { + let table_name = tbl.table_name.as_ref().unwrap(); + self.get_table_from_current_scope(table_name) + .ok_or_else(|| { + polars_err!( + SQLInterface: "no table or alias named '{}' found", + tbl + ) + }) + } else { + polars_bail!(SQLInterface: "'TABLE' requires valid table name") + } + }, + op => { + let op = match op { + SetExpr::SetOperation { op, .. } => op, + _ => unreachable!(), + }; + polars_bail!(SQLInterface: "'{}' operation is currently unsupported", op) + }, } } + #[cfg(feature = "semi_anti_join")] + fn process_except_intersect( + &mut self, + left: &SetExpr, + right: &SetExpr, + quantifier: &SetQuantifier, + query: &Query, + ) -> PolarsResult { + let (join_type, op_name) = match *query.body { + SetExpr::SetOperation { + op: SetOperator::Except, + .. 
+ } => (JoinType::Anti, "EXCEPT"), + _ => (JoinType::Semi, "INTERSECT"), + }; + let mut lf = self.process_set_expr(left, query)?; + let mut rf = self.process_set_expr(right, query)?; + let join = lf + .clone() + .join_builder() + .with(rf.clone()) + .how(join_type) + .join_nulls(true); + + let lf_schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let lf_cols: Vec<_> = lf_schema.iter_names().map(|nm| col(nm)).collect(); + let joined_tbl = match quantifier { + SetQuantifier::ByName | SetQuantifier::AllByName => { + // note: 'BY NAME' is pending https://github.com/sqlparser-rs/sqlparser-rs/pull/1309 + join.on(lf_cols).finish() + }, + SetQuantifier::Distinct | SetQuantifier::None => { + let rf_schema = rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let rf_cols: Vec<_> = rf_schema.iter_names().map(|nm| col(nm)).collect(); + if lf_cols.len() != rf_cols.len() { + polars_bail!(SQLInterface: "{} requires equal number of columns in each table (use '{} BY NAME' to combine mismatched tables)", op_name, op_name) + } + join.left_on(lf_cols).right_on(rf_cols).finish() + }, + _ => { + polars_bail!(SQLInterface: "'{} {}' is not supported", op_name, quantifier.to_string()) + }, + }; + Ok(joined_tbl.unique(None, UniqueKeepStrategy::Any)) + } + fn process_union( &mut self, left: &SetExpr, @@ -291,32 +364,40 @@ impl SQLContext { quantifier: &SetQuantifier, query: &Query, ) -> PolarsResult { - let left = self.process_set_expr(left, query)?; - let right = self.process_set_expr(right, query)?; + let mut lf = self.process_set_expr(left, query)?; + let mut rf = self.process_set_expr(right, query)?; let opts = UnionArgs { parallel: true, to_supertypes: true, ..Default::default() }; match quantifier { - // UNION ALL - SetQuantifier::All => polars_lazy::dsl::concat(vec![left, right], opts), - // UNION [DISTINCT] - SetQuantifier::Distinct | SetQuantifier::None => { - let concatenated = polars_lazy::dsl::concat(vec![left, right], opts); - concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) + // UNION [ALL | DISTINCT] + SetQuantifier::All | SetQuantifier::Distinct | SetQuantifier::None => { + let lf_schema = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + let rf_schema = rf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; + if lf_schema.len() != rf_schema.len() { + polars_bail!(SQLInterface: "UNION requires equal number of columns in each table (use 'UNION BY NAME' to combine mismatched tables)") + } + let concatenated = polars_lazy::dsl::concat(vec![lf, rf], opts); + match quantifier { + SetQuantifier::Distinct | SetQuantifier::None => { + concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) + }, + _ => concatenated, + } }, // UNION ALL BY NAME #[cfg(feature = "diagonal_concat")] - SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts), + SetQuantifier::AllByName => concat_lf_diagonal(vec![lf, rf], opts), // UNION [DISTINCT] BY NAME #[cfg(feature = "diagonal_concat")] SetQuantifier::ByName | SetQuantifier::DistinctByName => { - let concatenated = concat_lf_diagonal(vec![left, right], opts); + let concatenated = concat_lf_diagonal(vec![lf, rf], opts); concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) }, #[allow(unreachable_patterns)] - _ => polars_bail!(SQLInterface: "'UNION {}' is not yet supported", quantifier), + _ => polars_bail!(SQLInterface: "'UNION {}' is not currently supported", quantifier), } } @@ -466,7 +547,7 @@ impl SQLContext { join_type => { polars_bail!( SQLInterface: - 
"join type '{:?}' not yet supported by polars-sql", join_type + "join type '{:?}' not currently supported", join_type ); }, }; @@ -763,7 +844,7 @@ impl SQLContext { let tbl_name = name.0.first().unwrap().value.as_str(); // CREATE TABLE IF NOT EXISTS if *if_not_exists && self.table_map.contains_key(tbl_name) { - polars_bail!(SQLInterface: "relation {} already exists", tbl_name); + polars_bail!(SQLInterface: "relation '{}' already exists", tbl_name); // CREATE OR REPLACE TABLE } if let Some(query) = query { @@ -776,7 +857,7 @@ impl SQLContext { .lazy(); Ok(out) } else { - polars_bail!(SQLInterface: "only `CREATE TABLE AS SELECT` is currently supported"); + polars_bail!(SQLInterface: "only `CREATE TABLE AS SELECT ...` is currently supported"); } } else { unreachable!() @@ -881,8 +962,7 @@ impl SQLContext { polars_bail!(SQLSyntax: "UNNEST table must have an alias"); } }, - - // Support bare table, optional with alias for now + // Support bare table, optionally with an alias, for now _ => polars_bail!(SQLInterface: "not yet implemented: {}", relation), } } @@ -909,11 +989,11 @@ impl SQLContext { fn process_order_by( &mut self, mut lf: LazyFrame, - ob: &[OrderByExpr], + order_by: &[OrderByExpr], ) -> PolarsResult { - let mut by = Vec::with_capacity(ob.len()); - let mut descending = Vec::with_capacity(ob.len()); - let mut nulls_last = Vec::with_capacity(ob.len()); + let mut by = Vec::with_capacity(order_by.len()); + let mut descending = Vec::with_capacity(order_by.len()); + let mut nulls_last = Vec::with_capacity(order_by.len()); let schema = Some(lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?); let column_names = schema @@ -923,7 +1003,7 @@ impl SQLContext { .map(|e| col(e)) .collect::>(); - for ob in ob { + for ob in order_by { // note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise // https://www.postgresql.org/docs/current/queries-order.html let desc_order = !ob.asc.unwrap_or(true); @@ -951,7 +1031,7 @@ impl SQLContext { ) -> PolarsResult { polars_ensure!( !contains_wildcard, - SQLSyntax: "GROUP BY error: cannot process wildcard in group_by" + SQLSyntax: "GROUP BY error (cannot process wildcard in group_by)" ); let schema_before = lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena)?; let group_by_keys_schema = @@ -1082,7 +1162,7 @@ impl SQLContext { cols(schema.iter_names()) }, e => polars_bail!( - SQLSyntax: "invalid wildcard expression: {:?}", + SQLSyntax: "invalid wildcard expression ({:?})", e ), }; @@ -1096,7 +1176,7 @@ impl SQLContext { contains_wildcard_exclude: &mut bool, ) -> PolarsResult { if options.opt_except.is_some() { - polars_bail!(SQLSyntax: "EXCEPT not supported; use EXCLUDE instead") + polars_bail!(SQLSyntax: "EXCEPT not supported (use EXCLUDE instead)") } Ok(match &options.opt_exclude { Some(ExcludeSelectItem::Single(ident)) => { diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 74ba07ac7830..e1ac26e98275 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -10,8 +10,9 @@ use polars_plan::prelude::col; use polars_plan::prelude::LiteralValue::Null; use polars_plan::prelude::{lit, StrptimeOptions}; use sqlparser::ast::{ - DateTimeField, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Ident, - Value as SQLValue, WindowSpec, WindowType, + DateTimeField, DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, + FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, Ident, + 
OrderByExpr, Value as SQLValue, WindowSpec, WindowType, }; use crate::sql_expr::{parse_extract_date_part, parse_sql_expr}; @@ -546,6 +547,12 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT unnest(column_1) from df; /// ``` Explode, + /// SQL 'array_agg' function + /// Concatenates the input expressions, including nulls, into an array. + /// ```sql + /// SELECT ARRAY_AGG(column_1, column_2, ...) from df; + /// ``` + ArrayAgg, /// SQL 'array_to_string' function /// Takes all elements of the array and joins them into one string. /// ```sql @@ -770,6 +777,7 @@ impl PolarsSQLFunctions { // ---- // Array functions // ---- + "array_agg" => Self::ArrayAgg, "array_contains" => Self::ArrayContains, "array_get" => Self::ArrayGet, "array_length" => Self::ArrayLength, @@ -795,10 +803,20 @@ impl PolarsSQLFunctions { impl SQLFunctionVisitor<'_> { pub(crate) fn visit_function(&mut self) -> PolarsResult { + use PolarsSQLFunctions::*; + let function_name = PolarsSQLFunctions::try_from_sql(self.func, self.ctx)?; let function = self.func; - let function_name = PolarsSQLFunctions::try_from_sql(function, self.ctx)?; - use PolarsSQLFunctions::*; + // TODO: implement the following functions where possible + if !function.within_group.is_empty() { + polars_bail!(SQLInterface: "'WITHIN GROUP' is not currently supported") + } + if function.filter.is_some() { + polars_bail!(SQLInterface: "'FILTER' is not currently supported") + } + if function.null_treatment.is_some() { + polars_bail!(SQLInterface: "'IGNORE|RESPECT NULLS' is not currently supported") + } match function_name { // ---- @@ -807,7 +825,7 @@ impl SQLFunctionVisitor<'_> { Abs => self.visit_unary(Expr::abs), Cbrt => self.visit_unary(Expr::cbrt), Ceil => self.visit_unary(Expr::ceil), - Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64),), + Div => self.visit_binary(|e, d| e.floor_div(d).cast(DataType::Int64)), Exp => self.visit_unary(Expr::exp), Floor => self.visit_unary(Expr::floor), Ln => self.visit_unary(|e| e.log(std::f64::consts::E)), @@ -818,19 +836,22 @@ impl SQLFunctionVisitor<'_> { Pi => self.visit_nullary(Expr::pi), Mod => self.visit_binary(|e1, e2| e1 % e2), Pow => self.visit_binary::(Expr::pow), - Round => match function.args.len() { - 1 => self.visit_unary(|e| e.round(0)), - 2 => self.try_visit_binary(|e, decimals| { - Ok(e.round(match decimals { - Expr::Literal(LiteralValue::Int(n)) => { - if n >= 0 { n as u32 } else { - polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", function.args[1]) - } - }, - _ => polars_bail!(SQLSyntax: "invalid decimals value for ROUND ({})", function.args[1]), - })) - }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for ROUND (expected 1-2, found {})", function.args.len()), + Round => { + let args = extract_args(function)?; + match args.len() { + 1 => self.visit_unary(|e| e.round(0)), + 2 => self.try_visit_binary(|e, decimals| { + Ok(e.round(match decimals { + Expr::Literal(LiteralValue::Int(n)) => { + if n >= 0 { n as u32 } else { + polars_bail!(SQLInterface: "ROUND does not currently support negative decimals value ({})", args[1]) + } + }, + _ => polars_bail!(SQLSyntax: "invalid value for ROUND decimals ({})", args[1]), + })) + }), + _ => polars_bail!(SQLSyntax: "invalid number of arguments for ROUND (expected 1-2, found {})", args.len()), + } }, Sign => self.visit_unary(Expr::sign), Sqrt => self.visit_unary(Expr::sqrt), @@ -862,44 +883,71 @@ impl SQLFunctionVisitor<'_> { // ---- Coalesce => self.visit_variadic(coalesce), Greatest => 
self.visit_variadic(|exprs: &[Expr]| max_horizontal(exprs).unwrap()), - If => match function.args.len() { - 3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| { - Ok(when(cond).then(expr1).otherwise(expr2)) - }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for IF: {}", function.args.len() - ), + If => { + let args = extract_args(function)?; + match args.len() { + 3 => self.try_visit_ternary(|cond: Expr, expr1: Expr, expr2: Expr| { + Ok(when(cond).then(expr1).otherwise(expr2)) + }), + _ => { + polars_bail!(SQLSyntax: "invalid number of arguments for IF ({})", args.len() + ) + }, + } }, - IfNull => match function.args.len() { - 2 => self.visit_variadic(coalesce), - _ => polars_bail!(SQLSyntax:"Invalid number of arguments for IFNULL: {}", function.args.len()) + IfNull => { + let args = extract_args(function)?; + match args.len() { + 2 => self.visit_variadic(coalesce), + _ => { + polars_bail!(SQLSyntax:"Invalid number of arguments for IFNULL ({})", args.len()) + }, + } }, Least => self.visit_variadic(|exprs: &[Expr]| min_horizontal(exprs).unwrap()), - NullIf => match function.args.len() { - 2 => self.visit_binary(|l: Expr, r: Expr| when(l.clone().eq(r)).then(lit(LiteralValue::Null)).otherwise(l)), - _ => polars_bail!(SQLSyntax:"Invalid number of arguments for NULLIF: {}", function.args.len()) + NullIf => { + let args = extract_args(function)?; + match args.len() { + 2 => self.visit_binary(|l: Expr, r: Expr| { + when(l.clone().eq(r)) + .then(lit(LiteralValue::Null)) + .otherwise(l) + }), + _ => { + polars_bail!(SQLSyntax:"Invalid number of arguments for NULLIF ({})", args.len()) + }, + } }, // ---- // Date functions // ---- - Date => match function.args.len() { - 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())), - 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for DATE: {}", function.args.len()), + Date => { + let args = extract_args(function)?; + match args.len() { + 1 => self.visit_unary(|e| e.str().to_date(StrptimeOptions::default())), + 2 => self.visit_binary(|e, fmt| e.str().to_date(fmt)), + _ => { + polars_bail!(SQLSyntax: "invalid number of arguments for DATE ({})", args.len()) + }, + } }, DatePart => self.try_visit_binary(|part, e| { match part { Expr::Literal(LiteralValue::String(p)) => { // note: 'DATE_PART' and 'EXTRACT' are minor syntactic // variations on otherwise identical functionality - parse_extract_date_part(e, &DateTimeField::Custom(Ident { - value: p, - quote_style: None, - })) + parse_extract_date_part( + e, + &DateTimeField::Custom(Ident { + value: p, + quote_style: None, + }), + ) }, _ => { - polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART: {}", function.args[1]); - } + polars_bail!(SQLSyntax: "invalid 'part' for EXTRACT/DATE_PART ({})", part); + }, } }), @@ -907,20 +955,26 @@ impl SQLFunctionVisitor<'_> { // String functions // ---- BitLength => self.visit_unary(|e| e.str().len_bytes() * lit(8)), - Concat => if function.args.is_empty() { - polars_bail!(SQLSyntax: "invalid number of arguments for CONCAT: 0"); - } else { - self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true)) + Concat => { + let args = extract_args(function)?; + if args.is_empty() { + polars_bail!(SQLSyntax: "invalid number of arguments for CONCAT (0)"); + } else { + self.visit_variadic(|exprs: &[Expr]| concat_str(exprs, "", true)) + } }, - ConcatWS => if function.args.len() < 2 { - polars_bail!(SQLSyntax: "invalid number of arguments for CONCAT_WS: {}", 
function.args.len()); - } else { - self.try_visit_variadic(|exprs: &[Expr]| { - match &exprs[0] { - Expr::Literal(LiteralValue::String(s)) => Ok(concat_str(&exprs[1..], s, true)), - _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string; found {:?}", exprs[0]), - } - }) + ConcatWS => { + let args = extract_args(function)?; + if args.len() < 2 { + polars_bail!(SQLSyntax: "invalid number of arguments for CONCAT_WS ({})", args.len()); + } else { + self.try_visit_variadic(|exprs: &[Expr]| { + match &exprs[0] { + Expr::Literal(LiteralValue::String(s)) => Ok(concat_str(&exprs[1..], s, true)), + _ => polars_bail!(SQLSyntax: "CONCAT_WS 'separator' must be a literal string (found {:?})", exprs[0]), + } + }) + } }, EndsWith => self.visit_binary(|e, s| e.str().ends_with(s)), #[cfg(feature = "nightly")] @@ -930,53 +984,74 @@ impl SQLFunctionVisitor<'_> { Expr::Literal(Null) => lit(Null), Expr::Literal(LiteralValue::Int(0)) => lit(""), Expr::Literal(LiteralValue::Int(n)) => { - let len = if n > 0 { lit(n) } else { (e.clone().str().len_chars() + lit(n)).clip_min(lit(0)) }; + let len = if n > 0 { + lit(n) + } else { + (e.clone().str().len_chars() + lit(n)).clip_min(lit(0)) + }; e.str().slice(lit(0), len) }, - Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT: {}", function.args[1]), - _ => { - when(length.clone().gt_eq(lit(0))) - .then(e.clone().str().slice(lit(0), length.clone().abs())) - .otherwise(e.clone().str().slice(lit(0), (e.clone().str().len_chars() + length.clone()).clip_min(lit(0)))) - } - } - )}), + Expr::Literal(v) => { + polars_bail!(SQLSyntax: "invalid 'n_chars' for LEFT ({:?})", v) + }, + _ => when(length.clone().gt_eq(lit(0))) + .then(e.clone().str().slice(lit(0), length.clone().abs())) + .otherwise(e.clone().str().slice( + lit(0), + (e.clone().str().len_chars() + length.clone()).clip_min(lit(0)), + )), + }) + }), Length => self.visit_unary(|e| e.str().len_chars()), Lower => self.visit_unary(|e| e.str().to_lowercase()), - LTrim => match function.args.len() { - 1 => self.visit_unary(|e| e.str().strip_chars_start(lit(Null))), - 2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for LTRIM: {}", function.args.len()), + LTrim => { + let args = extract_args(function)?; + match args.len() { + 1 => self.visit_unary(|e| e.str().strip_chars_start(lit(Null))), + 2 => self.visit_binary(|e, s| e.str().strip_chars_start(s)), + _ => { + polars_bail!(SQLSyntax: "invalid number of arguments for LTRIM ({})", args.len()) + }, + } }, OctetLength => self.visit_unary(|e| e.str().len_bytes()), StrPos => { // note: 1-indexed, not 0-indexed, and returns zero if match not found - self.visit_binary(|expr, substring| (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32))) + self.visit_binary(|expr, substring| { + (expr.str().find(substring, true) + typed_lit(1u32)).fill_null(typed_lit(0u32)) + }) }, - RegexpLike => match function.args.len() { - 2 => self.visit_binary(|e, s| e.str().contains(s, true)), - 3 => self.try_visit_ternary(|e, pat, flags| { - Ok(e.str().contains( - match (pat, flags) { - (Expr::Literal(LiteralValue::String(s)), Expr::Literal(LiteralValue::String(f))) => { - if f.is_empty() { - polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE: {}", function.args[2]); - }; - lit(format!("(?{}){}", f, s)) + RegexpLike => { + let args = extract_args(function)?; + match args.len() { + 2 => self.visit_binary(|e, s| e.str().contains(s, true)), + 3 => 
self.try_visit_ternary(|e, pat, flags| { + Ok(e.str().contains( + match (pat, flags) { + (Expr::Literal(LiteralValue::String(s)), Expr::Literal(LiteralValue::String(f))) => { + if f.is_empty() { + polars_bail!(SQLSyntax: "invalid/empty 'flags' for REGEXP_LIKE ({})", args[2]); + }; + lit(format!("(?{}){}", f, s)) + }, + _ => { + polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE ({}, {})", args[1], args[2]); + }, }, - _ => { - polars_bail!(SQLSyntax: "invalid arguments for REGEXP_LIKE: {}, {}", function.args[1], function.args[2]); - }, - }, - true)) - }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for REGEXP_LIKE: {}",function.args.len()), + true)) + }), + _ => polars_bail!(SQLSyntax: "invalid number of arguments for REGEXP_LIKE ({})",args.len()), + } }, - Replace => match function.args.len() { - 3 => self.try_visit_ternary(|e, old, new| { - Ok(e.str().replace_all(old, new, true)) - }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for REPLACE: {}", function.args.len()), + Replace => { + let args = extract_args(function)?; + match args.len() { + 3 => self + .try_visit_ternary(|e, old, new| Ok(e.str().replace_all(old, new, true))), + _ => { + polars_bail!(SQLSyntax: "invalid number of arguments for REPLACE ({})", args.len()) + }, + } }, Reverse => self.visit_unary(|e| e.str().reverse()), Right => self.try_visit_binary(|e, length| { @@ -985,58 +1060,73 @@ impl SQLFunctionVisitor<'_> { Expr::Literal(LiteralValue::Int(0)) => typed_lit(""), Expr::Literal(LiteralValue::Int(n)) => { let n: i64 = n.try_into().unwrap(); - let offset = if n < 0 { lit(n.abs()) } else { e.clone().str().len_chars().cast(DataType::Int32) - lit(n) }; + let offset = if n < 0 { + lit(n.abs()) + } else { + e.clone().str().len_chars().cast(DataType::Int32) - lit(n) + }; e.str().slice(offset, lit(Null)) }, - Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT: {}", function.args[1]), + Expr::Literal(v) => { + polars_bail!(SQLSyntax: "invalid 'n_chars' for RIGHT ({:?})", v) + }, + _ => when(length.clone().lt(lit(0))) + .then(e.clone().str().slice(length.clone().abs(), lit(Null))) + .otherwise(e.clone().str().slice( + e.clone().str().len_chars().cast(DataType::Int32) - length.clone(), + lit(Null), + )), + }) + }), + RTrim => { + let args = extract_args(function)?; + match args.len() { + 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), + 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)), _ => { - when(length.clone().lt(lit(0))) - .then(e.clone().str().slice(length.clone().abs(), lit(Null))) - .otherwise(e.clone().str().slice(e.clone().str().len_chars().cast(DataType::Int32) - length.clone(), lit(Null))) - } + polars_bail!(SQLSyntax: "invalid number of arguments for RTRIM ({})", args.len()) + }, } - )}), - RTrim => match function.args.len() { - 1 => self.visit_unary(|e| e.str().strip_chars_end(lit(Null))), - 2 => self.visit_binary(|e, s| e.str().strip_chars_end(s)), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for RTRIM: {}", function.args.len()), }, StartsWith => self.visit_binary(|e, s| e.str().starts_with(s)), - Substring => match function.args.len() { - // note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments - 2 => self.try_visit_binary(|e, start| { - Ok(match start { - Expr::Literal(Null) => lit(Null), - Expr::Literal(LiteralValue::Int(n)) if n <= 0 => e, - Expr::Literal(LiteralValue::Int(n)) => e.str().slice(lit(n - 1), lit(Null)), - Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR: 
{}", function.args[1]), - _ => start.clone() + lit(1), - }) - }), - 3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| { - Ok(match (start.clone(), length.clone()) { - (Expr::Literal(Null), _) | (_, Expr::Literal(Null)) => lit(Null), - (_, Expr::Literal(LiteralValue::Int(n))) if n < 0 => { - polars_bail!(SQLSyntax: "SUBSTR does not support negative length: {}", function.args[2]) - }, - (Expr::Literal(LiteralValue::Int(n)), _) if n > 0 => e.str().slice(lit(n - 1), length.clone()), - (Expr::Literal(LiteralValue::Int(n)), _) => { - e.str().slice(lit(0), (length.clone() + lit(n - 1)).clip_min(lit(0))) - }, - (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR: {}", function.args[1]), - (_, Expr::Literal(LiteralValue::Float(_))) => { - polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR: {}", function.args[1]) - }, - _ => { - let adjusted_start = start.clone() - lit(1); - when(adjusted_start.clone().lt(lit(0))) - .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0)))) - .otherwise(e.clone().str().slice(adjusted_start.clone(), length.clone())) - } - }) - }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for SUBSTR: {}", function.args.len()), - } + Substring => { + let args = extract_args(function)?; + match args.len() { + // note that SQL is 1-indexed, not 0-indexed, hence the need for adjustments + 2 => self.try_visit_binary(|e, start| { + Ok(match start { + Expr::Literal(Null) => lit(Null), + Expr::Literal(LiteralValue::Int(n)) if n <= 0 => e, + Expr::Literal(LiteralValue::Int(n)) => e.str().slice(lit(n - 1), lit(Null)), + Expr::Literal(_) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]), + _ => start.clone() + lit(1), + }) + }), + 3 => self.try_visit_ternary(|e: Expr, start: Expr, length: Expr| { + Ok(match (start.clone(), length.clone()) { + (Expr::Literal(Null), _) | (_, Expr::Literal(Null)) => lit(Null), + (_, Expr::Literal(LiteralValue::Int(n))) if n < 0 => { + polars_bail!(SQLSyntax: "SUBSTR does not support negative length ({})", args[2]) + }, + (Expr::Literal(LiteralValue::Int(n)), _) if n > 0 => e.str().slice(lit(n - 1), length.clone()), + (Expr::Literal(LiteralValue::Int(n)), _) => { + e.str().slice(lit(0), (length.clone() + lit(n - 1)).clip_min(lit(0))) + }, + (Expr::Literal(_), _) => polars_bail!(SQLSyntax: "invalid 'start' for SUBSTR ({})", args[1]), + (_, Expr::Literal(LiteralValue::Float(_))) => { + polars_bail!(SQLSyntax: "invalid 'length' for SUBSTR ({})", args[1]) + }, + _ => { + let adjusted_start = start.clone() - lit(1); + when(adjusted_start.clone().lt(lit(0))) + .then(e.clone().str().slice(lit(0), (length.clone() + adjusted_start.clone()).clip_min(lit(0)))) + .otherwise(e.clone().str().slice(adjusted_start.clone(), length.clone())) + } + }) + }), + _ => polars_bail!(SQLSyntax: "invalid number of arguments for SUBSTR ({})", args.len()), + } + }, Upper => self.visit_unary(|e| e.str().to_uppercase()), // ---- @@ -1056,6 +1146,7 @@ impl SQLFunctionVisitor<'_> { // ---- // Array functions // ---- + ArrayAgg => self.visit_arr_agg(), ArrayContains => self.visit_binary::(|e, s| e.list().contains(s)), ArrayGet => self.visit_binary(|e, i| e.list().get(i, true)), ArrayLength => self.visit_unary(|e| e.list().len()), @@ -1064,31 +1155,15 @@ impl SQLFunctionVisitor<'_> { ArrayMin => self.visit_unary(|e| e.list().min()), ArrayReverse => self.visit_unary(|e| e.list().reverse()), ArraySum => self.visit_unary(|e| e.list().sum()), - ArrayToString => match function.args.len() 
{ - 2 => self.try_visit_binary(|e, sep| { Ok(e.list().join(sep, true)) }), - #[cfg(feature = "list_eval")] - 3 => self.try_visit_ternary(|e, sep, null_value| { - match null_value { - Expr::Literal(LiteralValue::String(v)) => { - Ok(if v.is_empty() { - e.list().join(sep, true) - } else { - e.list().eval(col("").fill_null(lit(v)), false).list().join(sep, false) - }) - }, - _ => polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING: {}", function.args[2]), - } - }), - _ => polars_bail!(SQLSyntax: "invalid number of arguments for ARRAY_TO_STRING: {}", function.args.len()), - } + ArrayToString => self.visit_arr_to_string(), ArrayUnique => self.visit_unary(|e| e.list().unique()), Explode => self.visit_unary(|e| e.explode()), - Udf(func_name) => self.visit_udf(&func_name) + Udf(func_name) => self.visit_udf(&func_name), } } fn visit_udf(&mut self, func_name: &str) -> PolarsResult { - let args = extract_args(self.func) + let args = extract_args(self.func)? .into_iter() .map(|arg| { if let FunctionArgExpr::Expr(e) = arg { @@ -1172,7 +1247,7 @@ impl SQLFunctionVisitor<'_> { } fn visit_unary_no_window(&mut self, f: impl Fn(Expr) -> Expr) -> PolarsResult { - let args = extract_args(self.func); + let args = extract_args(self.func)?; match args.as_slice() { [FunctionArgExpr::Expr(sql_expr)] => { let expr = parse_sql_expr(sql_expr, self.ctx, None)?; @@ -1194,7 +1269,7 @@ impl SQLFunctionVisitor<'_> { &mut self, f: impl Fn(Expr, Arg) -> PolarsResult, ) -> PolarsResult { - let args = extract_args(self.func); + let args = extract_args(self.func)?; match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2)] => { let expr1 = parse_sql_expr(sql_expr1, self.ctx, None)?; @@ -1213,7 +1288,7 @@ impl SQLFunctionVisitor<'_> { &mut self, f: impl Fn(&[Expr]) -> PolarsResult, ) -> PolarsResult { - let args = extract_args(self.func); + let args = extract_args(self.func)?; let mut expr_args = vec![]; for arg in args { if let FunctionArgExpr::Expr(sql_expr) = arg { @@ -1229,7 +1304,7 @@ impl SQLFunctionVisitor<'_> { &mut self, f: impl Fn(Expr, Arg, Arg) -> PolarsResult, ) -> PolarsResult { - let args = extract_args(self.func); + let args = extract_args(self.func)?; match args.as_slice() { [FunctionArgExpr::Expr(sql_expr1), FunctionArgExpr::Expr(sql_expr2), FunctionArgExpr::Expr(sql_expr3)] => { @@ -1243,16 +1318,82 @@ impl SQLFunctionVisitor<'_> { } fn visit_nullary(&self, f: impl Fn() -> Expr) -> PolarsResult { - let args = extract_args(self.func); + let args = extract_args(self.func)?; if !args.is_empty() { return self.not_supported_error(); } Ok(f()) } + fn visit_arr_agg(&mut self) -> PolarsResult { + let (args, is_distinct, clauses) = extract_args_and_clauses(self.func)?; + match args.as_slice() { + [FunctionArgExpr::Expr(sql_expr)] => { + let mut base = parse_sql_expr(sql_expr, self.ctx, None)?; + if is_distinct { + base = base.unique_stable(); + } + for clause in clauses { + match clause { + FunctionArgumentClause::OrderBy(order_exprs) => { + base = self.apply_order_by(base, order_exprs.as_slice())?; + }, + FunctionArgumentClause::Limit(limit_expr) => { + let limit = parse_sql_expr(&limit_expr, self.ctx, None)?; + match limit { + Expr::Literal(LiteralValue::Int(n)) if n >= 0 => { + base = base.head(Some(n as usize)) + }, + _ => { + polars_bail!(SQLSyntax: "LIMIT in ARRAY_AGG must be a positive integer") + }, + }; + }, + _ => {}, + } + } + Ok(base.implode()) + }, + _ => { + polars_bail!(SQLSyntax: "ARRAY_AGG must have exactly one argument; found {}", args.len()) + }, + } + } 
+ + fn visit_arr_to_string(&mut self) -> PolarsResult { + let args = extract_args(self.func)?; + match args.len() { + 2 => self.try_visit_binary(|e, sep| { + Ok(e.cast(DataType::List(Box::from(DataType::String))) + .list() + .join(sep, true)) + }), + #[cfg(feature = "list_eval")] + 3 => self.try_visit_ternary(|e, sep, null_value| match null_value { + Expr::Literal(LiteralValue::String(v)) => Ok(if v.is_empty() { + e.cast(DataType::List(Box::from(DataType::String))) + .list() + .join(sep, true) + } else { + e.cast(DataType::List(Box::from(DataType::String))) + .list() + .eval(col("").fill_null(lit(v)), false) + .list() + .join(sep, false) + }), + _ => { + polars_bail!(SQLSyntax: "invalid null value for ARRAY_TO_STRING ({})", args[2]) + }, + }), + _ => { + polars_bail!(SQLSyntax: "invalid number of arguments for ARRAY_TO_STRING ({})", args.len()) + }, + } + } + fn visit_count(&mut self) -> PolarsResult { - let args = extract_args(self.func); - match (self.func.distinct, args.as_slice()) { + let (args, is_distinct) = extract_args_distinct(self.func)?; + match (is_distinct, args.as_slice()) { // count(*), count() (false, [FunctionArgExpr::Wildcard] | []) => Ok(len()), // count(column_name) @@ -1271,6 +1412,28 @@ impl SQLFunctionVisitor<'_> { } } + fn apply_order_by(&mut self, expr: Expr, order_by: &[OrderByExpr]) -> PolarsResult { + let mut by = Vec::with_capacity(order_by.len()); + let mut descending = Vec::with_capacity(order_by.len()); + let mut nulls_last = Vec::with_capacity(order_by.len()); + + for ob in order_by { + // note: if not specified 'NULLS FIRST' is default for DESC, 'NULLS LAST' otherwise + // https://www.postgresql.org/docs/current/queries-order.html + let desc_order = !ob.asc.unwrap_or(true); + by.push(parse_sql_expr(&ob.expr, self.ctx, None)?); + nulls_last.push(!ob.nulls_first.unwrap_or(desc_order)); + descending.push(desc_order); + } + Ok(expr.sort_by( + by, + SortMultipleOptions::default() + .with_order_descending_multi(descending) + .with_nulls_last_multi(nulls_last) + .with_maintain_order(true), + )) + } + fn apply_window_spec( &mut self, expr: Expr, @@ -1317,15 +1480,54 @@ impl SQLFunctionVisitor<'_> { } } -fn extract_args(sql_function: &SQLFunction) -> Vec<&FunctionArgExpr> { - sql_function - .args - .iter() - .map(|arg| match arg { - FunctionArg::Named { arg, .. 
} => arg, - FunctionArg::Unnamed(arg) => arg, - }) - .collect() +fn extract_args(func: &SQLFunction) -> PolarsResult> { + let (args, _, _) = _extract_func_args(func, false, false)?; + Ok(args) +} + +fn extract_args_distinct(func: &SQLFunction) -> PolarsResult<(Vec<&FunctionArgExpr>, bool)> { + let (args, is_distinct, _) = _extract_func_args(func, true, false)?; + Ok((args, is_distinct)) +} + +fn extract_args_and_clauses( + func: &SQLFunction, +) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec)> { + _extract_func_args(func, true, true) +} + +fn _extract_func_args( + func: &SQLFunction, + get_distinct: bool, + get_clauses: bool, +) -> PolarsResult<(Vec<&FunctionArgExpr>, bool, Vec)> { + match &func.args { + FunctionArguments::List(FunctionArgumentList { + args, + duplicate_treatment, + clauses, + }) => { + let is_distinct = matches!(duplicate_treatment, Some(DuplicateTreatment::Distinct)); + if !(get_clauses || get_distinct) && is_distinct { + polars_bail!(SQLSyntax: "unexpected use of DISTINCT found in '{}'", func.name) + } else if !get_clauses && !clauses.is_empty() { + polars_bail!(SQLSyntax: "unexpected clause found in '{}' ({})", func.name, clauses[0]) + } else { + let unpacked_args = args + .iter() + .map(|arg| match arg { + FunctionArg::Named { arg, .. } => arg, + FunctionArg::Unnamed(arg) => arg, + }) + .collect(); + Ok((unpacked_args, is_distinct, clauses.clone())) + } + }, + FunctionArguments::Subquery { .. } => { + Err(polars_err!(SQLInterface: "subquery not expected in {}", func.name)) + }, + FunctionArguments::None => Ok((vec![], false, vec![])), + } } pub(crate) trait FromSQLExpr { diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 227b1d7fa4fd..a8f46b0d9e13 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -15,9 +15,9 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "dtype-decimal")] use sqlparser::ast::ExactNumberInfo; use sqlparser::ast::{ - ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, + ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Function as SQLFunction, Ident, - Interval, JoinConstraint, ObjectName, OrderByExpr, Query as Subquery, SelectItem, TimezoneInfo, + Interval, JoinConstraint, ObjectName, Query as Subquery, SelectItem, TimezoneInfo, TrimWhereField, UnaryOperator, Value as SQLValue, }; use sqlparser::dialect::GenericDialect; @@ -43,7 +43,7 @@ fn timeunit_from_precision(prec: &Option) -> PolarsResult { Some(n) if (4u64..=6u64).contains(n) => TimeUnit::Microseconds, Some(n) if (7u64..=9u64).contains(n) => TimeUnit::Nanoseconds, Some(n) => { - polars_bail!(SQLSyntax: "invalid temporal type precision; expected 1-9, found {}", n) + polars_bail!(SQLSyntax: "invalid temporal type precision (expected 1-9, found {})", n) }, }) } @@ -54,7 +54,7 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type, _)) => { DataType::List(Box::new(map_sql_polars_datatype(inner_type)?)) }, @@ -101,7 +101,7 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult DataType::Float32, Some(n) if (25u64..=53u64).contains(n) => DataType::Float64, Some(n) => { - polars_bail!(SQLSyntax: "unsupported `float` size; expected a value between 1 and 53, found {}", n) + polars_bail!(SQLSyntax: "unsupported `float` size (expected a value between 1 and 53, 
found {})", n) }, None => DataType::Float64, }, @@ -204,15 +204,15 @@ impl SQLExprVisitor<'_> { SQLExpr::Value(v) => self.visit_any_value(v, None), SQLExpr::UnaryOp { op, expr } => match expr.as_ref() { SQLExpr::Value(v) => self.visit_any_value(v, Some(op)), - _ => Err(polars_err!(SQLInterface: "expression {:?} is not yet supported", e)), + _ => Err(polars_err!(SQLInterface: "expression {:?} is not currently supported", e)), }, SQLExpr::Array(_) => { // TODO: nested arrays (handle FnMut issues) // let srs = self.array_expr_to_series(&[e.clone()])?; // Ok(AnyValue::List(srs)) - Err(polars_err!(SQLInterface: "nested array literals are not yet supported:\n{:?}", e)) + Err(polars_err!(SQLInterface: "nested array literals are not currently supported:\n{:?}", e)) }, - _ => Err(polars_err!(SQLInterface: "expression {:?} is not yet supported", e)), + _ => Err(polars_err!(SQLInterface: "expression {:?} is not currently supported", e)), }) .collect::>>()?; @@ -232,7 +232,6 @@ impl SQLExprVisitor<'_> { right, } => self.visit_any(left, compare_op, right), SQLExpr::Array(arr) => self.visit_array_expr(&arr.elem, true, None), - SQLExpr::ArrayAgg(expr) => self.visit_arr_agg(expr), SQLExpr::Between { expr, negated, @@ -241,10 +240,11 @@ impl SQLExprVisitor<'_> { } => self.visit_between(expr, *negated, low, high), SQLExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right), SQLExpr::Cast { + kind, expr, data_type, format, - } => self.visit_cast(expr, data_type, format, true), + } => self.visit_cast(expr, data_type, format, kind), SQLExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()), SQLExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents), SQLExpr::Extract { field, expr } => { @@ -323,16 +323,11 @@ impl SQLExprVisitor<'_> { trim_what, trim_characters, } => self.visit_trim(expr, trim_where, trim_what, trim_characters), - SQLExpr::TryCast { - expr, - data_type, - format, - } => self.visit_cast(expr, data_type, format, false), SQLExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr), SQLExpr::Value(value) => self.visit_literal(value), e @ SQLExpr::Case { .. } => self.visit_case_when_then(e), other => { - polars_bail!(SQLInterface: "expression {:?} is not yet supported", other) + polars_bail!(SQLInterface: "expression {:?} is not currently supported", other) }, } } @@ -343,9 +338,8 @@ impl SQLExprVisitor<'_> { restriction: SubqueryRestriction, ) -> PolarsResult { if subquery.with.is_some() { - polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE `with` clause"); + polars_bail!(SQLSyntax: "SQL subquery cannot be a CTE 'WITH' clause"); } - let mut lf = self.ctx.execute_query_no_ctes(subquery)?; let schema = lf.schema_with_arenas(&mut self.ctx.lp_arena, &mut self.ctx.expr_arena)?; @@ -425,7 +419,7 @@ impl SQLExprVisitor<'_> { || interval.leading_precision.is_some() || interval.fractional_seconds_precision.is_some() { - polars_bail!(SQLSyntax: "unsupported interval syntax: '{}'", interval) + polars_bail!(SQLSyntax: "unsupported interval syntax ('{}')", interval) } let s = match &*interval.value { SQLExpr::UnaryOp { .. 
} => { @@ -448,11 +442,11 @@ impl SQLExprVisitor<'_> { negated: bool, expr: &SQLExpr, pattern: &SQLExpr, - escape_char: &Option, + escape_char: &Option, case_insensitive: bool, ) -> PolarsResult { if escape_char.is_some() { - polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not yet supported; found '{}'", escape_char.unwrap()); + polars_bail!(SQLInterface: "ESCAPE char for LIKE/ILIKE is not currently supported; found '{}'", escape_char.clone().unwrap()); } let pat = match self.visit_expr(pattern) { Ok(Expr::Literal(LiteralValue::String(s))) => s, @@ -602,7 +596,7 @@ impl SQLExprVisitor<'_> { }, }, other => { - polars_bail!(SQLInterface: "SQL operator {:?} is not yet supported", other) + polars_bail!(SQLInterface: "SQL operator {:?} is not currently supported", other) }, }) } @@ -725,7 +719,7 @@ impl SQLExprVisitor<'_> { expr: &SQLExpr, data_type: &SQLDataType, format: &Option, - strict: bool, + cast_kind: &CastKind, ) -> PolarsResult { if format.is_some() { return Err( @@ -739,10 +733,9 @@ impl SQLExprVisitor<'_> { return Ok(expr.str().json_decode(None, None)); } let polars_type = map_sql_polars_datatype(data_type)?; - Ok(if strict { - expr.strict_cast(polars_type) - } else { - expr.cast(polars_type) + Ok(match cast_kind { + CastKind::Cast | CastKind::DoubleColon => expr.strict_cast(polars_type), + CastKind::TryCast | CastKind::SafeCast => expr.cast(polars_type), }) } @@ -833,7 +826,7 @@ impl SQLExprVisitor<'_> { SQLValue::SingleQuotedString(s) | SQLValue::DoubleQuotedString(s) => { AnyValue::StringOwned(s.into()) }, - other => polars_bail!(SQLInterface: "SQL value {:?} is not yet supported", other), + other => polars_bail!(SQLInterface: "SQL value {:?} is not currently supported", other), }) } @@ -889,33 +882,6 @@ impl SQLExprVisitor<'_> { }) } - /// Visit a SQL `ARRAY_AGG` expression. - fn visit_arr_agg(&mut self, expr: &ArrayAgg) -> PolarsResult { - let mut base = self.visit_expr(&expr.expr)?; - if let Some(order_by) = expr.order_by.as_ref() { - let (order_by, descending) = self.visit_order_by(order_by)?; - base = base.sort_by( - order_by, - SortMultipleOptions::default().with_order_descending_multi(descending), - ); - } - if let Some(limit) = &expr.limit { - let limit = match self.visit_expr(limit)? { - Expr::Literal(LiteralValue::Int(n)) if n >= 0 => n as usize, - _ => polars_bail!(SQLSyntax: "limit in ARRAY_AGG must be a positive integer"), - }; - base = base.head(Some(limit)); - } - if expr.distinct { - base = base.unique_stable(); - } - polars_ensure!( - !expr.within_group, - SQLInterface: "ARRAY_AGG WITHIN GROUP is not yet supported" - ); - Ok(base.implode()) - } - /// Visit a SQL subquery inside and `IN` expression. fn visit_in_subquery( &mut self, @@ -932,20 +898,6 @@ impl SQLExprVisitor<'_> { } } - /// Visit a SQL `ORDER BY` expression. - fn visit_order_by(&mut self, order_by: &[OrderByExpr]) -> PolarsResult<(Vec, Vec)> { - let mut expr = Vec::with_capacity(order_by.len()); - let mut descending = Vec::with_capacity(order_by.len()); - for order_by_expr in order_by { - let e = self.visit_expr(&order_by_expr.expr)?; - expr.push(e); - let desc = order_by_expr.asc.unwrap_or(false); - descending.push(desc); - } - - Ok((expr, descending)) - } - /// Visit `CASE` control flow expression. 
fn visit_case_when_then(&mut self, expr: &SQLExpr) -> PolarsResult { if let SQLExpr::Case { @@ -1024,7 +976,7 @@ impl SQLExprVisitor<'_> { } fn err(&self, expr: &Expr) -> PolarsResult { - polars_bail!(SQLInterface: "expression {:?} is not yet supported", expr); + polars_bail!(SQLInterface: "expression {:?} is not currently supported", expr); } } diff --git a/crates/polars-sql/tests/statements.rs b/crates/polars-sql/tests/statements.rs index 922603aeac79..0af4ae64fa86 100644 --- a/crates/polars-sql/tests/statements.rs +++ b/crates/polars-sql/tests/statements.rs @@ -562,6 +562,9 @@ fn test_join_utf8() { ); } +#[test] +fn test_table() {} + #[test] #[should_panic] fn test_compound_invalid_1() { diff --git a/py-polars/docs/source/reference/sql/functions/array.rst b/py-polars/docs/source/reference/sql/functions/array.rst index 01674937fa79..9bf76d0dc374 100644 --- a/py-polars/docs/source/reference/sql/functions/array.rst +++ b/py-polars/docs/source/reference/sql/functions/array.rst @@ -7,6 +7,8 @@ Array * - Function - Description + * - :ref:`ARRAY_AGG ` + - Aggregate a column/expression values as an array. * - :ref:`ARRAY_CONTAINS ` - Returns true if the array contains the value. * - :ref:`ARRAY_GET ` @@ -30,6 +32,35 @@ Array * - :ref:`UNNEST ` - Unnests (explodes) an array column into multiple rows. + +.. _array_agg: + +ARRAY_AGG +--------- +Aggregate a column/expression as an array (equivalent to `implode`). + +Supports optional inline `ORDER BY` and `LIMIT` clauses. + +**Example:** + +.. code-block:: python + + df = pl.DataFrame({"foo": [1, 2, 3], "bar": [4, 5, 6]}) + df.sql(""" + SELECT + ARRAY_AGG(foo ORDER BY foo DESC) AS arr_foo, + ARRAY_AGG(bar LIMIT 2) AS arr_bar + FROM self + """) + # shape: (1, 2) + # ┌───────────┬───────────┐ + # │ arr_foo ┆ arr_bar │ + # │ --- ┆ --- │ + # │ list[i64] ┆ list[i64] │ + # ╞═══════════╪═══════════╡ + # │ [3, 2, 1] ┆ [4, 5] │ + # └───────────┴───────────┘ + .. _array_contains: ARRAY_CONTAINS @@ -42,17 +73,17 @@ Returns true if the array contains the value. df = pl.DataFrame({"foo": [[1, 2], [4, 3]]}) df.sql(""" - SELECT ARRAY_CONTAINS(foo, 2) FROM self + SELECT foo, ARRAY_CONTAINS(foo, 2) AS has_two FROM self """) - # shape: (2, 1) - # ┌───────┐ - # │ foo │ - # │ --- │ - # │ bool │ - # ╞═══════╡ - # │ true │ - # │ false │ - # └───────┘ + # shape: (2, 2) + # ┌───────────┬─────────┐ + # │ foo ┆ has_two │ + # │ --- ┆ --- │ + # │ list[i64] ┆ bool │ + # ╞═══════════╪═════════╡ + # │ [1, 2] ┆ true │ + # │ [4, 3] ┆ false │ + # └───────────┴─────────┘ .. _array_get: @@ -71,17 +102,21 @@ Returns the value at the given index in the array. } ) df.sql(""" - SELECT ARRAY_GET(foo, 1), ARRAY_GET(bar, 2) FROM self + SELECT + foo, bar, + ARRAY_GET(foo, 1) AS foo_at_1, + ARRAY_GET(bar, 2) AS bar_at_2 + FROM self """) - # shape: (2, 2) - # ┌─────┬──────┐ - # │ foo ┆ bar │ - # │ --- ┆ --- │ - # │ i64 ┆ i64 │ - # ╞═════╪══════╡ - # │ 2 ┆ null │ - # │ 3 ┆ 10 │ - # └─────┴──────┘ + # shape: (2, 4) + # ┌───────────┬────────────┬──────────┬──────────┐ + # │ foo ┆ bar ┆ foo_at_1 ┆ bar_at_2 │ + # │ --- ┆ --- ┆ --- ┆ --- │ + # │ list[i64] ┆ list[i64] ┆ i64 ┆ i64 │ + # ╞═══════════╪════════════╪══════════╪══════════╡ + # │ [1, 2] ┆ [6, 7] ┆ 2 ┆ null │ + # │ [4, 3, 2] ┆ [8, 9, 10] ┆ 3 ┆ 10 │ + # └───────────┴────────────┴──────────┴──────────┘ .. _array_length: @@ -95,17 +130,17 @@ Returns the length of the array. 
df = pl.DataFrame({"foo": [[1, 2], [4, 3, 2]]}) df.sql(""" - SELECT ARRAY_LENGTH(foo) FROM self + SELECT foo, ARRAY_LENGTH(foo) AS n_elems FROM self """) - # shape: (2, 1) - # ┌─────┐ - # │ foo │ - # │ --- │ - # │ u32 │ - # ╞═════╡ - # │ 2 │ - # │ 3 │ - # └─────┘ + # shape: (2, 2) + # ┌───────────┬─────────┐ + # │ foo ┆ n_elems │ + # │ --- ┆ --- │ + # │ list[i64] ┆ u32 │ + # ╞═══════════╪═════════╡ + # │ [1, 2] ┆ 2 │ + # │ [4, 3, 2] ┆ 3 │ + # └───────────┴─────────┘ .. _array_lower: @@ -117,24 +152,19 @@ Returns the lower bound (min value) in an array. .. code-block:: python - df = pl.DataFrame( - { - "foo": [[1, 2], [4, 3, -2]], - "bar": [[6, 7], [8, 9, 10]] - } - ) + df = pl.DataFrame({"foo": [[1, 2], [4, -2, 8]]}) df.sql(""" - SELECT ARRAY_LOWER(foo), ARRAY_LOWER(bar) FROM self + SELECT foo, ARRAY_LOWER(foo) AS min_elem FROM self """) # shape: (2, 2) - # ┌─────┬─────┐ - # │ foo ┆ bar │ - # │ --- ┆ --- │ - # │ i64 ┆ i64 │ - # ╞═════╪═════╡ - # │ 1 ┆ 6 │ - # │ -2 ┆ 8 │ - # └─────┴─────┘ + # ┌────────────┬──────────┐ + # │ foo ┆ min_elem │ + # │ --- ┆ --- │ + # │ list[i64] ┆ i64 │ + # ╞════════════╪══════════╡ + # │ [1, 2] ┆ 1 │ + # │ [4, -2, 8] ┆ -2 │ + # └────────────┴──────────┘ .. _array_mean: @@ -146,28 +176,19 @@ Returns the mean of all values in an array. .. code-block:: python - df = pl.DataFrame( - { - "foo": [[1, 2], [4, 3, -1]], - "bar": [[6, 7], [8, 9, 10]] - } - ) + df = pl.DataFrame({"foo": [[1, 2], [4, 3, -1]]}) df.sql(""" - SELECT - ARRAY_MEAN(foo) AS foo_mean, - ARRAY_MEAN(bar) AS bar_mean - FROM self + SELECT foo, ARRAY_MEAN(foo) AS foo_mean FROM self """) - # shape: (2, 2) - # ┌──────────┬──────────┐ - # │ foo_mean ┆ bar_mean │ - # │ --- ┆ --- │ - # │ f64 ┆ f64 │ - # ╞══════════╪══════════╡ - # │ 1.5 ┆ 6.5 │ - # │ 2.0 ┆ 9.0 │ - # └──────────┴──────────┘ + # ┌────────────┬──────────┐ + # │ foo ┆ foo_mean │ + # │ --- ┆ --- │ + # │ list[i64] ┆ f64 │ + # ╞════════════╪══════════╡ + # │ [1, 2] ┆ 1.5 │ + # │ [4, 3, -1] ┆ 2.0 │ + # └────────────┴──────────┘ .. _array_reverse: @@ -187,20 +208,20 @@ Returns the array with the elements in reverse order. ) df.sql(""" SELECT + foo, ARRAY_REVERSE(foo) AS oof, ARRAY_REVERSE(bar) AS rab FROM self """) - - # shape: (2, 2) - # ┌───────────┬────────────┐ - # │ oof ┆ rab │ - # │ --- ┆ --- │ - # │ list[i64] ┆ list[i64] │ - # ╞═══════════╪════════════╡ - # │ [2, 1] ┆ [7, 6] │ - # │ [2, 3, 4] ┆ [10, 9, 8] │ - # └───────────┴────────────┘ + # shape: (2, 3) + # ┌───────────┬───────────┬────────────┐ + # │ foo ┆ oof ┆ rab │ + # │ --- ┆ --- ┆ --- │ + # │ list[i64] ┆ list[i64] ┆ list[i64] │ + # ╞═══════════╪═══════════╪════════════╡ + # │ [1, 2] ┆ [2, 1] ┆ [7, 6] │ + # │ [4, 3, 2] ┆ [2, 3, 4] ┆ [10, 9, 8] │ + # └───────────┴───────────┴────────────┘ .. _array_sum: @@ -212,28 +233,22 @@ Returns the sum of all values in an array. .. code-block:: python - df = pl.DataFrame( - { - "foo": [[1, -2], [-4, 3, -2]], - "bar": [[-6, 7], [8, -9, 10]] - } - ) + df = pl.DataFrame({"foo": [[1, -2], [10, 3, -2]]}) df.sql(""" SELECT - ARRAY_SUM(foo) AS foo_sum, - ARRAY_SUM(bar) AS bar_sum + foo, + ARRAY_SUM(foo) AS foo_sum FROM self """) - # shape: (2, 2) - # ┌─────────┬─────────┐ - # │ foo_sum ┆ bar_sum │ - # │ --- ┆ --- │ - # │ i64 ┆ i64 │ - # ╞═════════╪═════════╡ - # │ -1 ┆ 1 │ - # │ -3 ┆ 9 │ - # └─────────┴─────────┘ + # ┌─────────────┬─────────┐ + # │ foo ┆ foo_sum │ + # │ --- ┆ --- │ + # │ list[i64] ┆ i64 │ + # ╞═════════════╪═════════╡ + # │ [1, -2] ┆ -1 │ + # │ [10, 3, -2] ┆ 11 │ + # └─────────────┴─────────┘ .. 
_array_to_string: @@ -245,19 +260,27 @@ Takes all elements of the array and joins them into one string. .. code-block:: python - df = pl.DataFrame({"foo": [["a", "b"], ["c", "d", "e"]]}) + df = pl.DataFrame( + { + "foo": [["a", "b"], ["c", "d", "e"]], + "bar": [[8, None, 8], [3, 2, 1, 0]], + } + ) df.sql(""" - SELECT ARRAY_TO_STRING(foo,',') AS foo_str FROM self + SELECT + ARRAY_TO_STRING(foo,':') AS s_foo, + ARRAY_TO_STRING(bar,':') AS s_bar + FROM self """) - # shape: (2, 1) - # ┌─────────┐ - # │ foo_str │ - # │ --- │ - # │ str │ - # ╞═════════╡ - # │ a,b │ - # │ c,d,e │ - # └─────────┘ + # shape: (2, 2) + # ┌───────┬─────────┐ + # │ s_foo ┆ s_bar │ + # │ --- ┆ --- │ + # │ str ┆ str │ + # ╞═══════╪═════════╡ + # │ a:b ┆ 8:8 │ + # │ c:d:e ┆ 3:2:1:0 │ + # └───────┴─────────┘ .. _array_unique: @@ -293,28 +316,19 @@ Returns the upper bound (max value) in an array. .. code-block:: python - df = pl.DataFrame( - { - "foo": [[1, 2], [4, 3, -2]], - "bar": [[6, 7], [8, 9, 10]] - } - ) + df = pl.DataFrame({"foo": [[5, 0], [4, 8, -2]]}) df.sql(""" - SELECT - ARRAY_UPPER(foo) AS foo_max, - ARRAY_UPPER(bar) AS bar_max - FROM self + SELECT foo, ARRAY_UPPER(foo) AS max_elem FROM self """) - # shape: (2, 2) - # ┌─────────┬─────────┐ - # │ foo_max ┆ bar_max │ - # │ --- ┆ --- │ - # │ i64 ┆ i64 │ - # ╞═════════╪═════════╡ - # │ 2 ┆ 7 │ - # │ 4 ┆ 10 │ - # └─────────┴─────────┘ + # ┌────────────┬──────────┐ + # │ foo ┆ max_elem │ + # │ --- ┆ --- │ + # │ list[i64] ┆ i64 │ + # ╞════════════╪══════════╡ + # │ [5, 0] ┆ 5 │ + # │ [4, 8, -2] ┆ 8 │ + # └────────────┴──────────┘ .. _unnest: @@ -338,7 +352,6 @@ Unnest/explode an array column into multiple rows. UNNEST(bar) AS b FROM self """) - # shape: (5, 2) # ┌─────┬─────┐ # │ f ┆ b │ diff --git a/py-polars/docs/source/reference/sql/set_operations.rst b/py-polars/docs/source/reference/sql/set_operations.rst index 7ea0e4d5d8e9..5f80a1b21264 100644 --- a/py-polars/docs/source/reference/sql/set_operations.rst +++ b/py-polars/docs/source/reference/sql/set_operations.rst @@ -7,6 +7,12 @@ Set Operations * - Function - Description + * - :ref:`EXCEPT ` + - Combine the result sets of two SELECT statements, returning only the rows + that appear in the first result set but not in the second. + * - :ref:`INTERSECT ` + - Combine the result sets of two SELECT statements, returning only the rows + that appear in both result sets. * - :ref:`UNION ` - Combine the distinct result sets of two or more SELECT statements. The final result set will have no duplicate rows. @@ -19,12 +25,12 @@ Set Operations will have no duplicate rows. This also combines columns from both datasets. -.. _union: +.. _except: -UNION ------ -Combine the distinct result sets of two or more SELECT statements. -The final result set will have no duplicate rows. +EXCEPT +------ +Combine the result sets of two SELECT statements, returning only the rows +that appear in the first result set but not in the second. **Example:** @@ -39,12 +45,62 @@ The final result set will have no duplicate rows. "age": [30, 25, 45], "name": ["Bob", "Charlie", "David"], }) - lf_union = pl.sql(""" - SELECT id, name FROM df1 - UNION - SELECT id, name FROM df2 + pl.sql(""" + SELECT id, name FROM lf1 + EXCEPT + SELECT id, name FROM lf2 + """).sort(by="id").collect() + # shape: (1, 2) + # ┌─────┬───────┐ + # │ id ┆ name │ + # │ --- ┆ --- │ + # │ i64 ┆ str │ + # ╞═════╪═══════╡ + # │ 1 ┆ Alice │ + # └─────┴───────┘ + +.. 
_intersect: + +INTERSECT +--------- +Combine the result sets of two SELECT statements, returning only the rows +that appear in both result sets. + +**Example:** + +.. code-block:: python + + pl.sql(""" + SELECT id, name FROM lf1 + INTERSECT + SELECT id, name FROM lf2 """).sort(by="id").collect() + # shape: (2, 2) + # ┌─────┬─────────┐ + # │ id ┆ name │ + # │ --- ┆ --- │ + # │ i64 ┆ str │ + # ╞═════╪═════════╡ + # │ 2 ┆ Bob │ + # │ 3 ┆ Charlie │ + # └─────┴─────────┘ + +.. _union: + +UNION +----- +Combine the distinct result sets of two or more SELECT statements. +The final result set will have no duplicate rows. + +**Example:** + +.. code-block:: python + pl.sql(""" + SELECT id, name FROM lf1 + UNION + SELECT id, name FROM lf2 + """).sort(by="id").collect() # shape: (4, 2) # ┌─────┬─────────┐ # │ id ┆ name │ @@ -68,12 +124,11 @@ The final result set will be composed of all rows from each query. .. code-block:: python - lf_union_all = pl.sql(""" - SELECT id, name FROM df1 + pl.sql(""" + SELECT id, name FROM lf1 UNION ALL - SELECT id, name FROM df2 + SELECT id, name FROM lf2 """).sort(by="id").collect() - # shape: (6, 2) # ┌─────┬─────────┐ # │ id ┆ name │ @@ -100,12 +155,11 @@ will have no duplicate rows. This also combines columns from both datasets. .. code-block:: python - lf_union_by_name = pl.sql(""" - SELECT * FROM df1 + pl.sql(""" + SELECT * FROM lf1 UNION BY NAME - SELECT * FROM df2 + SELECT * FROM lf2 """).sort(by="id").collect() - # shape: (6, 3) # ┌─────┬─────────┬──────┐ # │ id ┆ name ┆ age │ diff --git a/py-polars/tests/unit/sql/test_array.py b/py-polars/tests/unit/sql/test_array.py index 75e8463e7b7a..62f9781f9296 100644 --- a/py-polars/tests/unit/sql/test_array.py +++ b/py-polars/tests/unit/sql/test_array.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import pytest import polars as pl @@ -7,6 +9,40 @@ from polars.testing import assert_frame_equal +@pytest.mark.parametrize( + ("sort_order", "limit", "expected"), + [ + (None, None, [("a", ["x", "y"]), ("b", ["z", "X", "Y"])]), + ("ASC", None, [("a", ["x", "y"]), ("b", ["z", "Y", "X"])]), + ("DESC", None, [("a", ["y", "x"]), ("b", ["X", "Y", "z"])]), + ("ASC", 2, [("a", ["x", "y"]), ("b", ["z", "Y"])]), + ("DESC", 2, [("a", ["y", "x"]), ("b", ["X", "Y"])]), + ("ASC", 1, [("a", ["x"]), ("b", ["z"])]), + ("DESC", 1, [("a", ["y"]), ("b", ["X"])]), + ], +) +def test_array_agg(sort_order: str | None, limit: int | None, expected: Any) -> None: + order_by = "" if not sort_order else f" ORDER BY col0 {sort_order}" + limit_clause = "" if not limit else f" LIMIT {limit}" + + res = pl.sql(f""" + WITH data (col0, col1, col2) as ( + VALUES + (1,'a','x'), + (2,'a','y'), + (4,'b','z'), + (8,'b','X'), + (7,'b','Y') + ) + SELECT col1, ARRAY_AGG(col2{order_by}{limit_clause}) AS arrs + FROM data + GROUP BY col1 + ORDER BY col1 + """).collect() + + assert res.rows() == expected + + def test_array_literals() -> None: with pl.SQLContext(df=None, eager=True) as ctx: res = ctx.execute( @@ -119,6 +155,6 @@ def test_unnest_table_function_errors() -> None: with pytest.raises( SQLInterfaceError, - match="nested array literals are not yet supported", + match="nested array literals are not currently supported", ): pl.sql_expr("[[1,2,3]] AS nested") diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py index 0ef152ccd5ce..e4cf20a62883 100644 --- a/py-polars/tests/unit/sql/test_group_by.py +++ b/py-polars/tests/unit/sql/test_group_by.py @@ -223,7 +223,7 @@ def test_group_by_errors() 
     with pytest.raises(
         SQLSyntaxError,
-        match=r"negative ordinals values are invalid for GROUP BY; found -99",
+        match=r"negative ordinal values are invalid for GROUP BY; found -99",
     ):
         df.sql("SELECT a, SUM(b) FROM self GROUP BY -99, a")
diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py
index 77d149ccd06c..69bf095e30e4 100644
--- a/py-polars/tests/unit/sql/test_numeric.py
+++ b/py-polars/tests/unit/sql/test_numeric.py
@@ -124,7 +124,7 @@ def test_round_ndigits_errors() -> None:
     df = pl.DataFrame({"n": [99.999]})
     with pl.SQLContext(df=df, eager=True) as ctx:
         with pytest.raises(
-            SQLSyntaxError, match=r"invalid decimals value for ROUND \('!!'\)"
+            SQLSyntaxError, match=r"invalid value for ROUND decimals \('!!'\)"
         ):
             ctx.execute("SELECT ROUND(n,'!!') AS n FROM df")
         with pytest.raises(
diff --git a/py-polars/tests/unit/sql/test_order_by.py b/py-polars/tests/unit/sql/test_order_by.py
index 364beb5a7583..691d6895be7b 100644
--- a/py-polars/tests/unit/sql/test_order_by.py
+++ b/py-polars/tests/unit/sql/test_order_by.py
@@ -237,6 +237,6 @@ def test_order_by_errors() -> None:
     with pytest.raises(
         SQLSyntaxError,
-        match="negative ordinals values are invalid for ORDER BY; found -1",
+        match="negative ordinal values are invalid for ORDER BY; found -1",
     ):
         df.sql("SELECT * FROM self ORDER BY -1")
diff --git a/py-polars/tests/unit/sql/test_set_ops.py b/py-polars/tests/unit/sql/test_set_ops.py
new file mode 100644
index 000000000000..10b4bccea120
--- /dev/null
+++ b/py-polars/tests/unit/sql/test_set_ops.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+from polars.exceptions import SQLInterfaceError
+from polars.testing import assert_frame_equal
+
+
+def test_except_intersect() -> None:
+    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
+    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})  # noqa: F841
+
+    res_e = pl.sql("SELECT x, y, z FROM df1 EXCEPT SELECT * FROM df2", eager=True)
+    res_i = pl.sql("SELECT * FROM df1 INTERSECT SELECT x, y, z FROM df2", eager=True)
+
+    assert sorted(res_e.rows()) == [(1, 2, 5), (9, 3, 5)]
+    assert sorted(res_i.rows()) == [(1, 4, 5)]
+
+    res_e = pl.sql("SELECT * FROM df2 EXCEPT TABLE df1", eager=True)
+    res_i = pl.sql(
+        """
+        SELECT * FROM df2
+        INTERSECT
+        SELECT x::int8, y::int8, z::int8
+        FROM (VALUES (1,2,5),(9,3,5),(1,4,5),(1,4,5)) AS df1(x,y,z)
+        """,
+        eager=True,
+    )
+    assert sorted(res_e.rows()) == [(1, 2, 7), (9, None, 6)]
+    assert sorted(res_i.rows()) == [(1, 4, 5)]
+
+    # check behaviour of nulls
+    with pl.SQLContext(
+        tbl1=pl.DataFrame({"x": [2, 9, 1], "y": [2, None, 4]}),
+        tbl2=pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4]}),
+    ) as ctx:
+        res = ctx.execute("SELECT * FROM tbl1 EXCEPT SELECT * FROM tbl2", eager=True)
+        assert_frame_equal(pl.DataFrame({"x": [2], "y": [2]}), res)
+
+
+@pytest.mark.parametrize("op", ["EXCEPT", "INTERSECT", "UNION"])
+def test_except_intersect_errors(op: str) -> None:
+    df1 = pl.DataFrame({"x": [1, 9, 1, 1], "y": [2, 3, 4, 4], "z": [5, 5, 5, 5]})  # noqa: F841
+    df2 = pl.DataFrame({"x": [1, 9, 1], "y": [2, None, 4], "z": [7, 6, 5]})  # noqa: F841
+
+    if op != "UNION":
+        with pytest.raises(
+            SQLInterfaceError,
+            match=f"'{op} ALL' is not supported",
+        ):
+            pl.sql(f"SELECT * FROM df1 {op} ALL SELECT * FROM df2", eager=False)
+
+    with pytest.raises(
+        SQLInterfaceError,
+        match=f"{op} requires equal number of columns in each table",
+    ):
+        pl.sql(f"SELECT x FROM df1 {op} SELECT x, y FROM df2", eager=False)
+
+
+@pytest.mark.parametrize(
+    ("cols1", "cols2", "union_subtype", "expected"),
+    [
+        (
+            ["*"],
+            ["*"],
+            "",
+            [(1, "zz"), (2, "yy"), (3, "xx")],
+        ),
+        (
+            ["*"],
+            ["frame2.*"],
+            "ALL",
+            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
+        ),
+        (
+            ["frame1.*"],
+            ["c1", "c2"],
+            "DISTINCT",
+            [(1, "zz"), (2, "yy"), (3, "xx")],
+        ),
+        (
+            ["*"],
+            ["c2", "c1"],
+            "ALL BY NAME",
+            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
+        ),
+        (
+            ["c1", "c2"],
+            ["c2", "c1"],
+            "BY NAME",
+            [(1, "zz"), (2, "yy"), (3, "xx")],
+        ),
+        pytest.param(
+            ["c1", "c2"],
+            ["c2", "c1"],
+            "DISTINCT BY NAME",
+            [(1, "zz"), (2, "yy"), (3, "xx")],
+        ),
+    ],
+)
+def test_union(
+    cols1: list[str],
+    cols2: list[str],
+    union_subtype: str,
+    expected: list[tuple[int, str]],
+) -> None:
+    with pl.SQLContext(
+        frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
+        frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}),
+        eager=True,
+    ) as ctx:
+        query = f"""
+            SELECT {', '.join(cols1)} FROM frame1
+            UNION {union_subtype}
+            SELECT {', '.join(cols2)} FROM frame2
+        """
+        assert sorted(ctx.execute(query).rows()) == expected
diff --git a/py-polars/tests/unit/sql/test_strings.py b/py-polars/tests/unit/sql/test_strings.py
index b27399e0290f..4d6e6c598986 100644
--- a/py-polars/tests/unit/sql/test_strings.py
+++ b/py-polars/tests/unit/sql/test_strings.py
@@ -98,12 +98,15 @@ def test_string_left_right_reverse() -> None:
         "r": ["de", "bc", "a", None],
         "rev": ["edcba", "cba", "a", None],
     }
-    for func, invalid in (("LEFT", "'xyz'"), ("RIGHT", "6.66")):
+    for func, invalid_arg, invalid_err in (
+        ("LEFT", "'xyz'", '"xyz"'),
+        ("RIGHT", "6.66", "(dyn float: 6.66)"),
+    ):
         with pytest.raises(
             SQLSyntaxError,
-            match=f"invalid 'n_chars' for {func}: {invalid}",
+            match=rf"""invalid 'n_chars' for {func} \({invalid_err}\)""",
         ):
-            ctx.execute(f"""SELECT {func}(txt,{invalid}) FROM df""").collect()
+            ctx.execute(f"""SELECT {func}(txt,{invalid_arg}) FROM df""").collect()
 
 
 def test_string_left_negative_expr() -> None:
@@ -349,7 +352,7 @@ def test_string_substr() -> None:
     with pytest.raises(
         SQLSyntaxError,
-        match="SUBSTR does not support negative length: -99",
+        match=r"SUBSTR does not support negative length \(-99\)",
     ):
         ctx.execute("SELECT SUBSTR(scol,2,-99) FROM df")
diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py
index 045fe5343934..fd509f64377f 100644
--- a/py-polars/tests/unit/sql/test_temporal.py
+++ b/py-polars/tests/unit/sql/test_temporal.py
@@ -272,7 +272,7 @@ def test_timestamp_time_unit_errors() -> None:
     for prec in (0, 15):
         with pytest.raises(
             SQLSyntaxError,
-            match=f"invalid temporal type precision; expected 1-9, found {prec}",
+            match=rf"invalid temporal type precision \(expected 1-9, found {prec}\)",
         ):
             ctx.execute(f"SELECT ts::timestamp({prec}) FROM frame_data")
diff --git a/py-polars/tests/unit/sql/test_union.py b/py-polars/tests/unit/sql/test_union.py
deleted file mode 100644
index d6975b0f7bee..000000000000
--- a/py-polars/tests/unit/sql/test_union.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from __future__ import annotations
-
-import pytest
-
-import polars as pl
-
-
-@pytest.mark.parametrize(
-    ("cols1", "cols2", "union_subtype", "expected"),
-    [
-        (
-            ["*"],
-            ["*"],
-            "",
-            [(1, "zz"), (2, "yy"), (3, "xx")],
-        ),
-        (
-            ["*"],
-            ["frame2.*"],
-            "ALL",
-            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
-        ),
-        (
-            ["frame1.*"],
-            ["c1", "c2"],
-            "DISTINCT",
-            [(1, "zz"), (2, "yy"), (3, "xx")],
-        ),
-        (
-            ["*"],
-            ["c2", "c1"],
-            "ALL BY NAME",
-            [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")],
-        ),
-        (
-            ["c1", "c2"],
-            ["c2", "c1"],
-            "BY NAME",
-            [(1, "zz"), (2, "yy"), (3, "xx")],
-        ),
-        pytest.param(
-            ["c1", "c2"],
-            ["c2", "c1"],
-            "DISTINCT BY NAME",
-            [(1, "zz"), (2, "yy"), (3, "xx")],
-        ),
-    ],
-)
-def test_union(
-    cols1: list[str],
-    cols2: list[str],
-    union_subtype: str,
-    expected: list[tuple[int, str]],
-) -> None:
-    with pl.SQLContext(
-        frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}),
-        frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}),
-        eager=True,
-    ) as ctx:
-        query = f"""
-            SELECT {', '.join(cols1)} FROM frame1
-            UNION {union_subtype}
-            SELECT {', '.join(cols2)} FROM frame2
-        """
-        assert sorted(ctx.execute(query).rows()) == expected
diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py
index bcab57b6612c..caff43275a4c 100644
--- a/py-polars/tests/unit/test_selectors.py
+++ b/py-polars/tests/unit/test_selectors.py
@@ -184,6 +184,9 @@ def test_selector_by_name(df: pl.DataFrame) -> None:
     ).columns
     assert selected_cols == ["fgg"]
 
+    for missing_column in ("missing", "???"):
+        assert df.select(cs.by_name(missing_column, require_all=False)).columns == []
+
     # check "by_name & col"
     for selector_expr, expected in (
         (cs.by_name("abc", "cde") & pl.col("ghi"), ["abc", "cde", "ghi"]),