Skip to content

Commit

Permalink
make function names versioned
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 14, 2023
1 parent f51916d commit 9ff2682
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
12 changes: 9 additions & 3 deletions example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use polars::prelude::*;
use polars_plan::dsl::FieldsMapper;
use pyo3_polars::derive::{polars_expr, CallerContext};
use pyo3_polars::export::polars_core::POOL;
use serde::Deserialize;
use std::fmt::Write;
use pyo3_polars::export::polars_core::POOL;

#[derive(Deserialize)]
struct PigLatinKwargs {
Expand Down Expand Up @@ -54,7 +54,11 @@ fn split_offsets(len: usize, n: usize) -> Vec<(usize, usize)> {

/// This expression will run in parallel if the `context` allows it.
#[polars_expr(output_type=Utf8)]
fn pig_latinnify_with_paralellism(inputs: &[Series], context: CallerContext, kwargs: PigLatinKwargs) -> PolarsResult<Series> {
fn pig_latinnify_with_paralellism(
inputs: &[Series],
context: CallerContext,
kwargs: PigLatinKwargs,
) -> PolarsResult<Series> {
use rayon::prelude::*;
let ca = inputs[0].utf8()?;

Expand All @@ -67,7 +71,9 @@ fn pig_latinnify_with_paralellism(inputs: &[Series], context: CallerContext, kwa
.into_par_iter()
.map(|(offset, len)| {
let sliced = ca.slice(offset as i64, len);
let out = sliced.apply_to_buffer(|value, output| pig_latin_str(value, kwargs.capitalize, output));
let out = sliced.apply_to_buffer(|value, output| {
pig_latin_str(value, kwargs.capitalize, output)
});
out.downcast_iter().cloned().collect::<Vec<_>>()
})
.collect();
Expand Down
47 changes: 31 additions & 16 deletions pyo3-polars-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ fn quote_process_results() -> proc_macro2::TokenStream {
quote!(match result {
Ok(out) => {
// Update return value.
*return_value = polars_ffi::export_series(&out);
*return_value = polars_ffi::version_0::export_series(&out);
}
Err(err) => {
// Set latest error, but leave return value in empty state.
Expand All @@ -100,12 +100,12 @@ fn quote_process_results() -> proc_macro2::TokenStream {

fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
// count how often the user define a kwargs argument.
let args = ast
let args = ast
.sig
.inputs
.iter()
.skip(1)
.map(| fn_arg| {
.map(|fn_arg| {
if let FnArg::Typed(pat) = fn_arg {
if let syn::Pat::Ident(pat) = pat.pat.as_ref() {
pat.ident.to_string()
Expand All @@ -127,17 +127,18 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
1 => match args[0].as_str() {
"kwargs" => quote_call_kwargs(&ast, fn_name),
"context" => quote_call_context(&ast, fn_name),
a => panic!("didn't expect argument {}", a)
a => panic!("didn't expect argument {}", a),
},
2 => match (args[0].as_str(), args[1].as_str()) {
("context", "kwargs") => quote_call_context_kwargs(&ast, fn_name),
("kwargs", "context") => panic!("'kwargs', 'context' order should be reversed"),
(a, b) => panic!("didn't expect arguments {}, {}", a, b)
}
_ => panic!("didn't expect so many arguments")
(a, b) => panic!("didn't expect arguments {}, {}", a, b),
},
_ => panic!("didn't expect so many arguments"),
};

let quote_process_result = quote_process_results();
let fn_name = get_expression_function_name(fn_name);

quote!(
use pyo3_polars::export::*;
Expand All @@ -147,15 +148,15 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {
// create the outer public function
#[no_mangle]
pub unsafe extern "C" fn #fn_name (
e: *mut polars_ffi::SeriesExport,
e: *mut polars_ffi::version_0::SeriesExport,
input_len: usize,
kwargs_ptr: *const u8,
kwargs_len: usize,
return_value: *mut polars_ffi::SeriesExport,
context: *mut polars_ffi::CallerContext
return_value: *mut polars_ffi::version_0::SeriesExport,
context: *mut polars_ffi::version_0::CallerContext
) {
let panic_result = std::panic::catch_unwind(move || {
let inputs = polars_ffi::import_series_buffer(e, input_len).unwrap();
let inputs = polars_ffi::version_0::import_series_buffer(e, input_len).unwrap();

#quote_call

Expand All @@ -165,16 +166,30 @@ fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream {

if panic_result.is_err() {
// Set latest to panic and nullify return value;
*return_value = polars_ffi::SeriesExport::empty();
*return_value = polars_ffi::version_0::SeriesExport::empty();
pyo3_polars::derive::_set_panic();
}

}
)
}

fn get_field_name(fn_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(&format!("__polars_field_{}", fn_name), fn_name.span())
fn get_field_function_name(fn_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(
&format!(
"__polars_field_{}_v{}",
fn_name,
polars_ffi::get_version().0
),
fn_name.span(),
)
}

fn get_expression_function_name(fn_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(
&format!("{}_v{}", fn_name, polars_ffi::get_version().0),
fn_name.span(),
)
}

fn get_inputs() -> proc_macro2::TokenStream {
Expand All @@ -192,7 +207,7 @@ fn create_field_function(
fn_name: &syn::Ident,
dtype_fn_name: &syn::Ident,
) -> proc_macro2::TokenStream {
let map_field_name = get_field_name(fn_name);
let map_field_name = get_field_function_name(fn_name);
let inputs = get_inputs();

quote! (
Expand Down Expand Up @@ -232,7 +247,7 @@ fn create_field_function_from_with_dtype(
fn_name: &syn::Ident,
dtype: syn::Ident,
) -> proc_macro2::TokenStream {
let map_field_name = get_field_name(fn_name);
let map_field_name = get_field_function_name(fn_name);
let inputs = get_inputs();

quote! (
Expand Down
2 changes: 1 addition & 1 deletion pyo3-polars/src/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::cell::RefCell;
use std::ffi::CString;

/// Gives the caller extra information on how to execute the expression.
pub use polars_ffi::CallerContext;
pub use polars_ffi::version_0::CallerContext;

/// A default opaque kwargs type.
pub type DefaultKwargs = serde_pickle::Value;
Expand Down

0 comments on commit 9ff2682

Please sign in to comment.