Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disable coercison for unmatched struct type #14409

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 58 additions & 25 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use arrow::datatypes::{
};
use datafusion_common::types::NativeType;
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, Diagnostic, Result, Span,
Spans,
exec_err, internal_err, plan_datafusion_err, plan_err, Diagnostic, HashMap, Result,
Span, Spans,
};
use itertools::Itertools;

Expand Down Expand Up @@ -595,7 +595,7 @@ fn type_union_resolution_coercion(

/// Handle type union resolution including struct type and others.
pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> {
let err = match try_type_union_resolution_with_struct(data_types) {
let err = match try_type_union_resolution_with_struct(data_types, false) {
Ok(struct_types) => return Ok(struct_types),
Err(e) => Some(e),
};
Expand All @@ -611,11 +611,16 @@ pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType
// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1"
pub fn try_type_union_resolution_with_struct(
data_types: &[DataType],
is_unique: bool,
) -> Result<Vec<DataType>> {
let mut keys_string: Option<String> = None;
for data_type in data_types {
if let DataType::Struct(fields) = data_type {
let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
let keys = fields
.iter()
.map(|f| f.name().to_owned())
.sorted()
.join(",");
if let Some(ref k) = keys_string {
if *k != keys {
return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
Expand All @@ -628,51 +633,75 @@ pub fn try_type_union_resolution_with_struct(
}
}

let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0]
let first_fields = if let DataType::Struct(fields) = &data_types[0] {
fields.clone()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

let mut struct_types_map: HashMap<String, DataType> = if let DataType::Struct(
fields,
) = &data_types[0]
{
fields.iter().map(|f| f.data_type().to_owned()).collect()
fields
.iter()
.map(|f| (f.name().to_owned(), f.data_type().to_owned()))
.collect()
} else {
return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
};

for data_type in data_types.iter().skip(1) {
if let DataType::Struct(fields) = data_type {
let incoming_struct_types: Vec<DataType> =
fields.iter().map(|f| f.data_type().to_owned()).collect();
// The order of field is verified above
for (lhs_type, rhs_type) in
struct_types.iter_mut().zip(incoming_struct_types.iter())
{
if let Some(coerced_type) =
type_union_resolution_coercion(lhs_type, rhs_type)
{
*lhs_type = coerced_type;
for field in fields.iter() {
let field_name = field.name();
if let Some(existing_type) = struct_types_map.get_mut(field_name) {
if let Some(coerced_type) =
type_union_resolution_coercion(field.data_type(), existing_type)
{
*existing_type = coerced_type;
} else {
return exec_err!(
"Fail to find the coerced type for {} and {}",
field.data_type(),
existing_type
);
}
} else {
return exec_err!(
"Fail to find the coerced type for {} and {}",
lhs_type,
rhs_type
);
return exec_err!("Field {} not found in first struct", field_name);
}
}
} else {
return exec_err!("Expect to get struct but got {}", data_type);
}
}

if is_unique {
let new_fields =
first_fields
.iter()
.map(|f| {
Arc::new(Arc::unwrap_or_clone(Arc::clone(f)).with_data_type(
struct_types_map.get(f.name()).unwrap().to_owned(),
))
})
.collect();
let unified_struct = DataType::Struct(new_fields);
return Ok(vec![unified_struct; data_types.len()]);
}

let mut final_struct_types = vec![];
for s in data_types {
let mut new_fields = vec![];
if let DataType::Struct(fields) = s {
for (i, f) in fields.iter().enumerate() {
for f in fields.iter() {
let field = Arc::unwrap_or_clone(Arc::clone(f))
.with_data_type(struct_types[i].to_owned());
.with_data_type(struct_types_map.get(f.name()).unwrap().to_owned()); // we can unwrap here since all fields are in the map
new_fields.push(Arc::new(field));
}
}
final_struct_types.push(DataType::Struct(new_fields.into()))
}

Ok(final_struct_types)
}

Expand Down Expand Up @@ -971,7 +1000,11 @@ fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Struct(lhs_fields), Struct(rhs_fields)) => {
if lhs_fields.len() != rhs_fields.len() {
if lhs_fields.len() != rhs_fields.len() || {
let l = lhs_fields.iter().map(|f| f.name()).sorted().join(",");
let r = rhs_fields.iter().map(|f| f.name()).sorted().join(",");
l != r
} {
return None;
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-nested/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl ScalarUDFImpl for MakeArray {

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let mut errors = vec![];
match try_type_union_resolution_with_struct(arg_types) {
match try_type_union_resolution_with_struct(arg_types, true) {
Ok(r) => return Ok(r),
Err(e) => {
errors.push(e);
Expand Down
28 changes: 28 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,34 @@ FROM t;
statement ok
drop table t

statement ok
create table t as values
(
{ 'foo': 'baz' },
{ 'xxx': arrow_cast('blarg', 'Utf8View') }
);

query error Failed to coerce then
select CASE WHEN 1=2 THEN column1 ELSE column2 END from t ;

statement ok
drop table t

statement ok
create table t as values
(
{ 'name': 'Alice', 'age': 25 },
{ 'age': 30, 'name': 'Bob' }
);

query ?
select CASE WHEN 1=2 THEN column1 ELSE column2 END from t;
----
{age: 30, name: Bob}

statement ok
drop table t

# Fix coercion of lists of structs
# https://github.com/apache/datafusion/issues/14154

Expand Down
23 changes: 23 additions & 0 deletions datafusion/sqllogictest/test_files/coalesce.slt
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,26 @@ Date32

statement ok
drop table test


statement ok
create table t as values
(
{ 'name': 'Alice', 'age': 20, 'id': 1},
{ 'age': 30, 'id': 2, 'name': 'Bob'},
{ 'name': 'Carol', 'id': 3, 'age': 22},
{ 'name': 'Carol', 'id': 3, 'age': 22, 'name':'test'}
);

query ?
select coalesce(column1, column2, column3) from t;
----
{name: Alice, age: 20, id: 1}

query ?
select coalesce(column1, column3) from t;
----
{name: Alice, age: 20, id: 1}

query error User-defined coercion failed
select coalesce(column1, column4) from t;
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,28 @@ FROM (
----
NULL false
foo true

statement ok
drop table t

statement ok
create table t as values
(
{ 'foo': 'baz' },
{ 'xxx': arrow_cast('blarg', 'Utf8View') },
{ 'name': 'Alice', 'age': 20 },
{ 'age': 30, 'name': 'Bob' }
);

query error Incompatible inputs for
select column1 from t UNION ALL select column2 from t;


query ?
select column3 from t UNION ALL select column4 from t order by column3;
----
{name: 30, age: Bob}
{name: Alice, age: 20}

statement ok
drop table t