You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by wj...@apache.org on 2023/09/04 19:19:34 UTC
[arrow-datafusion] 02/04: implement inlist guarantee use
This is an automated email from the ASF dual-hosted git repository.
wjones127 pushed a commit to branch 6171-simplify-with-guarantee
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit caa738f591470295bd6a4af026b1ba9d292f86bb
Author: Will Jones <wi...@gmail.com>
AuthorDate: Mon Sep 4 10:48:11 2023 -0700
implement inlist guarantee use
---
.../src/simplify_expressions/guarantees.rs | 104 +++++++++++++--------
1 file changed, 66 insertions(+), 38 deletions(-)
diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
index 4e142ef280..0772eaab50 100644
--- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs
+++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
@@ -18,7 +18,7 @@
//! Logic to inject guarantees with expressions.
//!
use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue};
-use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator};
+use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator};
use std::collections::HashMap;
/// A bound on the value of an expression.
@@ -108,6 +108,11 @@ impl Guarantee {
fn less_than_or_eq(&self, value: &ScalarValue) -> bool {
self.max.bound <= *value
}
+
+ /// Whether the guarantee could contain the given value.
+ fn contains(&self, value: &ScalarValue) -> bool {
+ !self.less_than(value) && !self.greater_than(value)
+ }
}
impl From<&ScalarValue> for Guarantee {
@@ -237,6 +242,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
_ => return Ok(expr),
};
+ // TODO: can this be simplified?
if let Some(guarantee) = self.guarantees.get(col.as_ref()) {
match op {
Operator::Eq => {
@@ -339,7 +345,35 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
}
}
- // In list
+ Expr::InList(InList {
+ expr: inner,
+ list,
+ negated,
+ }) => {
+ if let Some(guarantee) = self.guarantees.get(inner.as_ref()) {
+ // Literal items outside the guaranteed range can never match, so drop them from the list
+ let new_list: Vec<Expr> = list
+ .iter()
+ .filter(|item| {
+ if let Expr::Literal(item) = item {
+ guarantee.contains(item)
+ } else {
+ true
+ }
+ })
+ .cloned()
+ .collect();
+
+ Ok(Expr::InList(InList {
+ expr: inner.clone(),
+ list: new_list,
+ negated: *negated,
+ }))
+ } else {
+ Ok(expr)
+ }
+ }
+
_ => Ok(expr),
}
}
@@ -471,59 +505,53 @@ mod tests {
#[test]
fn test_in_list() {
let guarantees = vec![
- // x = 2
- (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))),
- // 1 <= y < 10
+ // 1 <= x < 10
(
- col("y"),
+ col("x"),
Guarantee::new(
Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)),
Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)),
NullStatus::NeverNull,
),
),
- // z is null
- (col("z"), Guarantee::from(&ScalarValue::Null)),
];
let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
- // These cases should be simplified
+ // These cases should be simplified so the list doesn't contain any
+ // values the guarantee says are outside the range.
+ // (column_name, starting_list, negated, expected_list)
let cases = &[
- // x IN ()
- (col("x").in_list(vec![], false), false),
- // x IN (10, 11)
- (col("x").in_list(vec![lit(10), lit(11)], false), false),
- // x IN (10, 2)
- (col("x").in_list(vec![lit(10), lit(2)], false), true),
- // x NOT IN (10, 2)
- (col("x").in_list(vec![lit(10), lit(2)], true), false),
- // y IN (10, 11)
- (col("y").in_list(vec![lit(10), lit(11)], false), false),
- // y NOT IN (0, 22)
- (col("y").in_list(vec![lit(0), lit(22)], true), true),
- // z IN (10, 11)
- (col("z").in_list(vec![lit(10), lit(11)], false), false),
+ // x IN (9, 11) => x IN (9)
+ ("x", vec![9, 11], false, vec![9]),
+ // x IN (10, 2) => x IN (2)
+ ("x", vec![10, 2], false, vec![2]),
+ // x NOT IN (9, 11) => x NOT IN (9)
+ ("x", vec![9, 11], true, vec![9]),
+ // x NOT IN (0, 22) => x NOT IN ()
+ ("x", vec![0, 22], true, vec![]),
];
- for (expr, expected_value) in cases {
+ for (column_name, starting_list, negated, expected_list) in cases {
+ let expr = col(*column_name).in_list(
+ starting_list
+ .iter()
+ .map(|v| lit(ScalarValue::Int32(Some(*v))))
+ .collect(),
+ *negated,
+ );
let output = expr.clone().rewrite(&mut rewriter).unwrap();
+ let expected_list = expected_list
+ .iter()
+ .map(|v| lit(ScalarValue::Int32(Some(*v))))
+ .collect();
assert_eq!(
output,
- Expr::Literal(ScalarValue::Boolean(Some(*expected_value)))
+ Expr::InList(InList {
+ expr: Box::new(col(*column_name)),
+ list: expected_list,
+ negated: *negated,
+ })
);
}
-
- // These cases should be left as-is
- let cases = &[
- // y IN (10, 2)
- col("y").in_list(vec![lit(10), lit(2)], false),
- // y NOT IN (10, 2)
- col("y").in_list(vec![lit(10), lit(2)], true),
- ];
-
- for expr in cases {
- let output = expr.clone().rewrite(&mut rewriter).unwrap();
- assert_eq!(&output, expr);
- }
}
}