Skip to content

Commit

Permalink
fix: Be more lenient in interpreting input args for builtin window fu…
Browse files Browse the repository at this point in the history
…nctions (apache#11199)

* fix: Be more lenient in interpreting input args for builtin window functions

The built-in window functions Lag, Lead, NthValue, Ntile
accept integer arguments. However while they should
allow any integers, currently as they just use
ScalarValue's try_from to convert into an i64, they
actually only accept i64s. Any other argument, e.g. an  i32,
would be converted into a None and ignored.

Before - lag and lead would silently ignore the argument, ntile and nth_value would fail:
```
> SELECT id, lead(id, -1) OVER (ORDER BY id) AS correct, lead(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as wrong from (values (1), (2)) as tbl(id);
+----+---------+-------+
| id | correct | wrong |
+----+---------+-------+
| 1  |         | 2     |
| 2  | 1       |       |
+----+---------+-------+

> SELECT id, lag(id, -1) OVER (ORDER BY id) AS correct, lag(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as wrong from (values (1), (2)) as tbl(id);
+----+---------+-------+
| id | correct | wrong |
+----+---------+-------+
| 1  | 2       |       |
| 2  |         | 1     |
+----+---------+-------+

> SELECT id, nth_value(id, 2) OVER (ORDER BY id) AS correct, nth_value(id, arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id);
Execution error: Internal("Cannot convert Int32(2) to i64")

> SELECT id, ntile(2) OVER (ORDER BY id) AS correct, ntile(arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id);
Internal error: Cannot convert Int32(2) to i64.
This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker
```

After - all four produce expected results:
```
SELECT id, lead(id, -1) OVER (ORDER BY id) AS correct, lead(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id)
+----+---------+-----------+
| id | correct | corrected |
+----+---------+-----------+
| 1  |         |           |
| 2  | 1       | 1         |
+----+---------+-----------+

SELECT id, lag(id, -1) OVER (ORDER BY id) AS correct, lag(id, arrow_cast(-1,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id)
+----+---------+-----------+
| id | correct | corrected |
+----+---------+-----------+
| 1  | 2       | 2         |
| 2  |         |           |
+----+---------+-----------+

SELECT id, nth_value(id, 2) OVER (ORDER BY id) AS correct, nth_value(id, arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id)
+----+---------+-----------+
| id | correct | corrected |
+----+---------+-----------+
| 1  |         |           |
| 2  | 2       | 2         |
+----+---------+-----------+

SELECT id, ntile(2) OVER (ORDER BY id) AS correct, ntile(arrow_cast(2,'Int32')) OVER (ORDER BY id) as corrected from (values (1), (2)) as tbl(id)
+----+---------+-----------+
| id | correct | corrected |
+----+---------+-----------+
| 1  | 1       | 1         |
| 2  | 2       | 2         |
+----+---------+-----------+
```

* cleanup

* make lead/lag throw if arg is invalid, check that the arg is int before casting, add tests

* return unsigned handling to ntile and move tests to sqllogictests window.slt

* remove unused import
  • Loading branch information
Blizzara authored Jul 3, 2024
1 parent c049a94 commit 6993561
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 12 deletions.
44 changes: 32 additions & 12 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,24 @@ fn get_scalar_value_from_args(
})
}

fn get_signed_integer(value: ScalarValue) -> Result<i64> {
if !value.data_type().is_integer() {
return Err(DataFusionError::Execution(
"Expected an integer value".to_string(),
));
}
value.cast_to(&DataType::Int64)?.try_into()
}

fn get_unsigned_integer(value: ScalarValue) -> Result<u64> {
if !value.data_type().is_integer() {
return Err(DataFusionError::Execution(
"Expected an integer value".to_string(),
));
}
value.cast_to(&DataType::UInt64)?.try_into()
}

fn get_casted_value(
default_value: Option<ScalarValue>,
dtype: &DataType,
Expand Down Expand Up @@ -259,10 +277,10 @@ fn create_built_in_window_expr(
}

if n.is_unsigned() {
let n: u64 = n.try_into()?;
let n = get_unsigned_integer(n)?;
Arc::new(Ntile::new(name, n, out_data_type))
} else {
let n: i64 = n.try_into()?;
let n: i64 = get_signed_integer(n)?;
if n <= 0 {
return exec_err!("NTILE requires a positive integer");
}
Expand All @@ -272,8 +290,8 @@ fn create_built_in_window_expr(
BuiltInWindowFunction::Lag => {
let arg = args[0].clone();
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(|v| v.try_into())
.and_then(|v| v.ok());
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
Arc::new(lag(
Expand All @@ -288,8 +306,8 @@ fn create_built_in_window_expr(
BuiltInWindowFunction::Lead => {
let arg = args[0].clone();
let shift_offset = get_scalar_value_from_args(args, 1)?
.map(|v| v.try_into())
.and_then(|v| v.ok());
.map(get_signed_integer)
.map_or(Ok(None), |v| v.map(Some))?;
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
Arc::new(lead(
Expand All @@ -303,11 +321,14 @@ fn create_built_in_window_expr(
}
BuiltInWindowFunction::NthValue => {
let arg = args[0].clone();
let n = args[1].as_any().downcast_ref::<Literal>().unwrap().value();
let n: i64 = n
.clone()
.try_into()
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
let n = get_signed_integer(
args[1]
.as_any()
.downcast_ref::<Literal>()
.unwrap()
.value()
.clone(),
)?;
Arc::new(NthValue::nth(
name,
arg,
Expand Down Expand Up @@ -618,7 +639,6 @@ mod tests {

use datafusion_functions_aggregate::count::count_udaf;
use futures::FutureExt;

use InputOrderMode::{Linear, PartiallySorted, Sorted};

fn create_test_schema() -> Result<SchemaRef> {
Expand Down
37 changes: 37 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4830,6 +4830,8 @@ NULL 3
NULL 2
NULL 1

statement ok
drop table t

### Test for window functions with arrays
statement ok
Expand All @@ -4852,3 +4854,38 @@ c [4, 5, 6] NULL

statement ok
drop table array_data

# Test for non-i64 offsets for NTILE, LAG, LEAD, NTH_VALUE
statement ok
CREATE TABLE t AS VALUES (3, 3), (4, 4), (5, 5), (6, 6);

query IIIIIIIII
SELECT
column1,
ntile(2) OVER (order by column1),
ntile(arrow_cast(2, 'Int32')) OVER (order by column1),
lag(column2, -1) OVER (order by column1),
lag(column2, arrow_cast(-1, 'Int32')) OVER (order by column1),
lead(column2, -1) OVER (order by column1),
lead(column2, arrow_cast(-1, 'Int32')) OVER (order by column1),
nth_value(column2, 2) OVER (order by column1),
nth_value(column2, arrow_cast(2, 'Int32')) OVER (order by column1)
FROM t;
----
3 1 1 4 4 NULL NULL NULL NULL
4 1 1 5 5 3 3 4 4
5 2 2 6 6 4 4 4 4
6 2 2 NULL NULL 5 5 4 4

# NTILE specifies the argument types so the error is different
query error
SELECT ntile(1.1) OVER (order by column1) FROM t;

query error DataFusion error: Execution error: Expected an integer value
SELECT lag(column2, 1.1) OVER (order by column1) FROM t;

query error DataFusion error: Execution error: Expected an integer value
SELECT lead(column2, 1.1) OVER (order by column1) FROM t;

query error DataFusion error: Execution error: Expected an integer value
SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t;

0 comments on commit 6993561

Please sign in to comment.