From fcf94751eb7b47d7cfc4e3d5818474b7cc207b4b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 14 Oct 2024 11:26:42 -0400 Subject: [PATCH] Modifications need to compile against latest DF --- Cargo.lock | 1 + Cargo.toml | 1 + src/udf.rs | 2 +- src/udwf.rs | 19 +++++++++++++++---- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7700a162..33d144ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1162,6 +1162,7 @@ dependencies = [ "async-trait", "datafusion", "datafusion-ffi", + "datafusion-functions-window-common", "datafusion-proto", "datafusion-substrait", "futures", diff --git a/Cargo.toml b/Cargo.toml index a8b62cc1..a27d0ec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ datafusion = { git = "https://github.com/timsaucer/datafusion.git", rev = "20756 datafusion-substrait = { git = "https://github.com/timsaucer/datafusion.git", rev = "20756df736006253f1ce9e94385b75ab44e268f8", optional = true } datafusion-proto = { git = "https://github.com/timsaucer/datafusion.git", rev = "20756df736006253f1ce9e94385b75ab44e268f8" } datafusion-ffi = { git = "https://github.com/timsaucer/datafusion.git", rev = "20756df736006253f1ce9e94385b75ab44e268f8" } +datafusion-functions-window-common = { git = "https://github.com/timsaucer/datafusion.git", rev = "20756df736006253f1ce9e94385b75ab44e268f8" } prost = "0.13" # keep in line with `datafusion-substrait` uuid = { version = "1.9", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] } diff --git a/src/udf.rs b/src/udf.rs index ec8efb16..ea56930e 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -97,7 +97,7 @@ impl PyScalarUDF { let function = create_udf( name, input_types.0, - Arc::new(return_type.0), + return_type.0, parse_volatility(volatility)?, to_scalar_function_impl(func), ); diff --git a/src/udwf.rs b/src/udwf.rs index 43c21ec7..106a8314 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::array::{make_array, Array, ArrayData, ArrayRef}; use datafusion::logical_expr::window_state::WindowAggState; use datafusion::scalar::ScalarValue; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -299,11 +300,21 @@ impl WindowUDFImpl for MultiColumnWindowUDF { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(self.return_type.clone()) + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + (self.partition_evaluator_factory)() } - fn partition_evaluator(&self) -> Result> { - (self.partition_evaluator_factory)() + fn field( + &self, + field_args: datafusion::logical_expr::function::WindowUDFFieldArgs, + ) -> Result { + Ok(arrow::datatypes::Field::new( + field_args.name(), + self.return_type.clone(), + true, + )) } }