You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by wj...@apache.org on 2023/09/04 19:19:34 UTC

[arrow-datafusion] 02/04: implement inlist guarantee use

This is an automated email from the ASF dual-hosted git repository.

wjones127 pushed a commit to branch 6171-simplify-with-guarantee
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit caa738f591470295bd6a4af026b1ba9d292f86bb
Author: Will Jones <wi...@gmail.com>
AuthorDate: Mon Sep 4 10:48:11 2023 -0700

    implement inlist guarantee use
---
 .../src/simplify_expressions/guarantees.rs         | 104 +++++++++++++--------
 1 file changed, 66 insertions(+), 38 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
index 4e142ef280..0772eaab50 100644
--- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs
+++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
@@ -18,7 +18,7 @@
 //! Logic to inject guarantees with expressions.
 //!
 use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue};
-use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator};
+use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator};
 use std::collections::HashMap;
 
 /// A bound on the value of an expression.
@@ -108,6 +108,11 @@ impl Guarantee {
     fn less_than_or_eq(&self, value: &ScalarValue) -> bool {
         self.max.bound <= *value
     }
+
+    /// Returns true unless the guarantee is known to be less than or greater than `value`.
+    fn contains(&self, value: &ScalarValue) -> bool {
+        !self.less_than(value) && !self.greater_than(value)
+    }
 }
 
 impl From<&ScalarValue> for Guarantee {
@@ -237,6 +242,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
                     _ => return Ok(expr),
                 };
 
+                // TODO: can this be simplified?
                 if let Some(guarantee) = self.guarantees.get(col.as_ref()) {
                     match op {
                         Operator::Eq => {
@@ -339,7 +345,35 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
                 }
             }
 
-            // In list
+            Expr::InList(InList {
+                expr: inner,
+                list,
+                negated,
+            }) => {
+                if let Some(guarantee) = self.guarantees.get(inner.as_ref()) {
+                    // Drop literals the guarantee rules out; non-literal items are always kept.
+                    let new_list: Vec<Expr> = list
+                        .iter()
+                        .filter(|item| {
+                            if let Expr::Literal(item) = item {
+                                guarantee.contains(item)
+                            } else {
+                                true
+                            }
+                        })
+                        .cloned()
+                        .collect();
+
+                    Ok(Expr::InList(InList {
+                        expr: inner.clone(),
+                        list: new_list,
+                        negated: *negated,
+                    }))
+                } else {
+                    Ok(expr)
+                }
+            }
+
             _ => Ok(expr),
         }
     }
@@ -471,59 +505,53 @@ mod tests {
     #[test]
     fn test_in_list() {
         let guarantees = vec![
-            // x = 2
-            (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))),
-            // 1 <= y < 10
+            // 1 <= x < 10
             (
-                col("y"),
+                col("x"),
                 Guarantee::new(
                     Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)),
                     Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)),
                     NullStatus::NeverNull,
                 ),
             ),
-            // z is null
-            (col("z"), Guarantee::from(&ScalarValue::Null)),
         ];
         let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
 
-        // These cases should be simplified
+        // These cases should be simplified so the list doesn't contain any
+        // values the guarantee says are outside the range.
+        // (column_name, starting_list, negated, expected_list)
         let cases = &[
-            // x IN ()
-            (col("x").in_list(vec![], false), false),
-            // x IN (10, 11)
-            (col("x").in_list(vec![lit(10), lit(11)], false), false),
-            // x IN (10, 2)
-            (col("x").in_list(vec![lit(10), lit(2)], false), true),
-            // x NOT IN (10, 2)
-            (col("x").in_list(vec![lit(10), lit(2)], true), false),
-            // y IN (10, 11)
-            (col("y").in_list(vec![lit(10), lit(11)], false), false),
-            // y NOT IN (0, 22)
-            (col("y").in_list(vec![lit(0), lit(22)], true), true),
-            // z IN (10, 11)
-            (col("z").in_list(vec![lit(10), lit(11)], false), false),
+            // x IN (9, 11) => x IN (9)
+            ("x", vec![9, 11], false, vec![9]),
+            // x IN (10, 2) => x IN (2)
+            ("x", vec![10, 2], false, vec![2]),
+            // x NOT IN (9, 11) => x NOT IN (9)
+            ("x", vec![9, 11], true, vec![9]),
+            // x NOT IN (0, 22) => x NOT IN ()
+            ("x", vec![0, 22], true, vec![]),
         ];
 
-        for (expr, expected_value) in cases {
+        for (column_name, starting_list, negated, expected_list) in cases {
+            let expr = col(*column_name).in_list(
+                starting_list
+                    .iter()
+                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
+                    .collect(),
+                *negated,
+            );
             let output = expr.clone().rewrite(&mut rewriter).unwrap();
+            let expected_list = expected_list
+                .iter()
+                .map(|v| lit(ScalarValue::Int32(Some(*v))))
+                .collect();
             assert_eq!(
                 output,
-                Expr::Literal(ScalarValue::Boolean(Some(*expected_value)))
+                Expr::InList(InList {
+                    expr: Box::new(col(*column_name)),
+                    list: expected_list,
+                    negated: *negated,
+                })
             );
         }
-
-        // These cases should be left as-is
-        let cases = &[
-            // y IN (10, 2)
-            col("y").in_list(vec![lit(10), lit(2)], false),
-            // y NOT IN (10, 2)
-            col("y").in_list(vec![lit(10), lit(2)], true),
-        ];
-
-        for expr in cases {
-            let output = expr.clone().rewrite(&mut rewriter).unwrap();
-            assert_eq!(&output, expr);
-        }
     }
 }