Skip to content

Commit

Permalink
Add support for InList expressions. (#418)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackson Newhouse authored Nov 22, 2023
1 parent e4eed97 commit bc962c5
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 1 deletion.
62 changes: 62 additions & 0 deletions arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1246,4 +1246,66 @@ mod tests {
},
arroyo_types::from_millis(168565954000)
);

// TEST InList
single_test_codegen!(
"in_list",
"non_nullable_i32 IN (1, 2, 3, 4, 5, 6)",
arroyo_sql::TestStruct {
non_nullable_i32: 2,
..Default::default()
},
true
);

single_test_codegen!(
"in_list_false",
"non_nullable_i32 IN (1, 2, 3, 5, 6, 7)",
arroyo_sql::TestStruct {
non_nullable_i32: 4,
..Default::default()
},
false
);

single_test_codegen!(
"in_list_nullable",
"nullable_i32 IN (1, 2, 3, 4, 5)",
arroyo_sql::TestStruct {
nullable_i32: Some(2),
..Default::default()
},
Some(true)
);

single_test_codegen!(
"in_list_nullable_false",
"nullable_i32 IN (1, 2, 3, 4, 5)",
arroyo_sql::TestStruct {
nullable_i32: None,
..Default::default()
},
None
);

// TEST NotInList
single_test_codegen!(
"not_in_list",
"non_nullable_i32 NOT IN (1, 2, 3, 5, 6)",
arroyo_sql::TestStruct {
non_nullable_i32: 4,
..Default::default()
},
true
);

single_test_codegen!(
"not_in_list_nullable",
"nullable_i32 NOT IN (NULL, 2, 3, 4, 5)",
arroyo_sql::TestStruct {
nullable_i32: None,
..Default::default()
},
None
);
}
79 changes: 78 additions & 1 deletion arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use arroyo_types::{DatePart, DateTruncPrecision};
use datafusion_common::ScalarValue;
use datafusion_expr::{
aggregate_function,
expr::{AggregateUDF, Alias, ScalarFunction, ScalarUDF, Sort},
expr::{AggregateUDF, Alias, InList, ScalarFunction, ScalarUDF, Sort},
type_coercion::aggregates::{avg_return_type, sum_return_type},
BinaryExpr, BuiltinScalarFunction, Expr, GetFieldAccess, TryCast,
};
Expand Down Expand Up @@ -505,6 +505,12 @@ impl Expression {
} => {
(&mut *array_expression).traverse_mut(context, f);
}
DataStructureFunction::InList { expr, list } => {
(&mut *expr).traverse_mut(context, f);
for e in list {
e.traverse_mut(context, f);
}
}
},
Expression::Json(e) => {
(&mut *e.json_string).traverse_mut(context, f);
Expand Down Expand Up @@ -981,6 +987,30 @@ impl<'a> ExpressionContext<'a> {
}))
}
},
Expr::InList(InList {
expr,
list,
negated,
}) => {
let in_list_expression = DataStructureFunction::InList {
expr: Box::new(self.compile_expr(expr)?),
list: list
.iter()
.map(|e| self.compile_expr(e))
.collect::<Result<Vec<_>>>()?,
};
if *negated {
Ok(BinaryComparisonExpression::new(
Box::new(Expression::Literal(LiteralExpression {
literal: ScalarValue::Boolean(Some(false)),
})),
datafusion_expr::Operator::Eq,
Box::new(Expression::DataStructure(in_list_expression)),
)?)
} else {
Ok(Expression::DataStructure(in_list_expression))
}
}
expression => {
bail!("expression {:?} not yet implemented", expression)
}
Expand Down Expand Up @@ -3279,6 +3309,10 @@ pub enum DataStructureFunction {
array_expression: Box<Expression>,
index: usize,
},
InList {
expr: Box<Expression>,
list: Vec<Expression>,
},
}

impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for DataStructureFunction {
Expand Down Expand Up @@ -3414,6 +3448,45 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for DataStructureFun
}),
}
}
DataStructureFunction::InList { expr, list } => {
let comparison_expr = expr.generate(input_context);
let comparison_nullable = expr.expression_type(input_context).is_optional();

let list_exprs: Vec<syn::Expr> = list
.iter()
.map(|term| {
let expr = term.generate(input_context);
let nullable = term.expression_type(input_context).is_optional();
let comparison_type = term.expression_type(input_context).return_type();
if nullable {
parse_quote!(
({let expr: #comparison_type = #expr; expr}
.map(|v| v == comparison_expr).unwrap_or(false))
)
} else {
parse_quote!(
(comparison_expr == #expr)
)
}
})
.collect();

if comparison_nullable {
parse_quote!({
match #comparison_expr{
Some(comparison_expr) => {
Some(#(#list_exprs)||*)
}
None => None,
}
})
} else {
parse_quote!({
let comparison_expr = #comparison_expr;
(#(#list_exprs)||*)
})
}
}
}
}
fn expression_type(&self, input_context: &ValuePointerContext) -> TypeDef {
Expand Down Expand Up @@ -3457,6 +3530,10 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for DataStructureFun
};
TypeDef::DataType(field.data_type().clone(), true)
}
DataStructureFunction::InList { expr, list: _ } => TypeDef::DataType(
DataType::Boolean,
expr.expression_type(input_context).is_optional(),
),
}
}
}
Expand Down

0 comments on commit bc962c5

Please sign in to comment.