Skip to content

Commit

Permalink
Ensure that math functions fulfil the ColumnarValue contract
Browse files Browse the repository at this point in the history
If all UDF arguments are scalars, so should be the result.
In most cases, such function calls will be contant-folded,
however if for whatever reason the are not optimized,
we want to avoid an error due to array length mismatch.
  • Loading branch information
joroKr21 committed Oct 14, 2024
1 parent 5391c98 commit 4919a25
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
14 changes: 12 additions & 2 deletions datafusion/expr-common/src/columnar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

//! [`ColumnarValue`] represents the result of evaluating an expression.

use arrow::array::ArrayRef;
use arrow::array::NullArray;
use arrow::array::{Array, ArrayRef, NullArray};
use arrow::compute::{kernels, CastOptions};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
Expand Down Expand Up @@ -218,6 +217,17 @@ impl ColumnarValue {
}
}
}

/// Converts an [`ArrayRef`] to a [`ColumnarValue`] based on the supplied arguments.
/// This is useful for scalar UDF implementations to fulfil their contract:
/// if all arguments are scalar values, the result should also be a scalar value.
pub fn from_args_and_result(args: &[Self], result: ArrayRef) -> Result<Self> {
if result.len() == 1 && args.iter().all(|arg| matches!(arg, Self::Scalar(_))) {
Ok(Self::Scalar(ScalarValue::try_from_array(&result, 0)?))
} else {
Ok(Self::Array(result))
}
}
}

#[cfg(test)]
Expand Down
16 changes: 8 additions & 8 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,8 @@ macro_rules! make_math_unary_udf {
$EVALUATE_BOUNDS(inputs)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

fn invoke(&self, col_args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(col_args)?;
let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => {
Arc::new(make_function_scalar_inputs_return_type!(
Expand All @@ -255,7 +254,8 @@ macro_rules! make_math_unary_udf {
)
}
};
Ok(ColumnarValue::Array(arr))

ColumnarValue::from_args_and_result(col_args, arr)
}
}
}
Expand Down Expand Up @@ -336,9 +336,8 @@ macro_rules! make_math_binary_udf {
$OUTPUT_ORDERING(input)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

fn invoke(&self, col_args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(col_args)?;
let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => Arc::new(make_function_inputs2!(
&args[0],
Expand All @@ -364,7 +363,8 @@ macro_rules! make_math_binary_udf {
)
}
};
Ok(ColumnarValue::Array(arr))

ColumnarValue::from_args_and_result(col_args, arr)
}
}
}
Expand Down

0 comments on commit 4919a25

Please sign in to comment.