Skip to content

Commit

Permalink
refactor(rust): Branch earlier in binary type resolving (pola-rs#16685)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 3, 2024
1 parent 49addca commit 15981d7
Showing 1 changed file with 58 additions and 56 deletions.
114 changes: 58 additions & 56 deletions crates/polars-plan/src/logical_plan/conversion/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,68 +291,70 @@ pub(super) fn process_binary(
(Struct(_), Struct(_), _op) => return Ok(None),
_ => {},
}
let compare_cat_to_string = compares_cat_to_string(&type_left, &type_right, op);
let datetime_arithmetic = is_datetime_arithmetic(&type_left, &type_right, op);
let list_arithmetic = is_list_arithmetic(&type_left, &type_right, op);
str_numeric_arithmetic(&type_left, &type_right)?;

// Special path for list arithmetic
if list_arithmetic {
return process_list_arithmetic(
type_left, type_right, node_left, node_right, op, expr_arena,
);
}
// ensure we don't enter this branch for common numeric arithmetic.
if op.is_arithmetic() {
if !(type_left.is_numeric() && type_right.is_numeric()) {
str_numeric_arithmetic(&type_left, &type_right)?;

#[cfg(feature = "dtype-struct")]
{
let is_struct_numeric_arithmetic =
is_struct_numeric_arithmetic(&type_left, &type_right, op);
if is_struct_numeric_arithmetic {
return process_struct_numeric_arithmetic(
type_left, type_right, node_left, node_right, op, expr_arena,
);
if is_datetime_arithmetic(&type_left, &type_right, op) {
return Ok(None);
}
// Special path for list arithmetic
if is_list_arithmetic(&type_left, &type_right, op) {
return process_list_arithmetic(
type_left, type_right, node_left, node_right, op, expr_arena,
);
}
#[cfg(feature = "dtype-struct")]
{
let is_struct_numeric_arithmetic =
is_struct_numeric_arithmetic(&type_left, &type_right, op);
if is_struct_numeric_arithmetic {
return process_struct_numeric_arithmetic(
type_left, type_right, node_left, node_right, op, expr_arena,
);
}
}
}
} else if compares_cat_to_string(&type_left, &type_right, op) {
return Ok(None);
}

// All early return paths
if compare_cat_to_string || datetime_arithmetic {
Ok(None)
} else {
// Coerce types:
let st = unpack!(get_supertype(&type_left, &type_right));
let mut st = modify_supertype(st, left, right, &type_left, &type_right);
// Coerce types:
let st = unpack!(get_supertype(&type_left, &type_right));
let mut st = modify_supertype(st, left, right, &type_left, &type_right);

if is_cat_str_binary(&type_left, &type_right) {
st = String
}
if is_cat_str_binary(&type_left, &type_right) {
st = String
}

// only cast if the type is not already the super type.
// this can prevent an expensive flattening and subsequent aggregation
// in a group_by context. To be able to cast the groups need to be
// flattened
let new_node_left = if type_left != st {
expr_arena.add(AExpr::Cast {
expr: node_left,
data_type: st.clone(),
strict: false,
})
} else {
node_left
};
let new_node_right = if type_right != st {
expr_arena.add(AExpr::Cast {
expr: node_right,
data_type: st,
strict: false,
})
} else {
node_right
};
// only cast if the type is not already the super type.
// this can prevent an expensive flattening and subsequent aggregation
// in a group_by context. To be able to cast the groups need to be
// flattened
let new_node_left = if type_left != st {
expr_arena.add(AExpr::Cast {
expr: node_left,
data_type: st.clone(),
strict: false,
})
} else {
node_left
};
let new_node_right = if type_right != st {
expr_arena.add(AExpr::Cast {
expr: node_right,
data_type: st,
strict: false,
})
} else {
node_right
};

Ok(Some(AExpr::BinaryExpr {
left: new_node_left,
op,
right: new_node_right,
}))
}
Ok(Some(AExpr::BinaryExpr {
left: new_node_left,
op,
right: new_node_right,
}))
}

0 comments on commit 15981d7

Please sign in to comment.