You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/08/25 17:24:06 UTC

[arrow-datafusion] branch master updated: Use `ExprRewriter` in `pre_cast_lit_in_comparison` (#3260)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ee52d08e Use `ExprRewriter` in `pre_cast_lit_in_comparison` (#3260)
5ee52d08e is described below

commit 5ee52d08ec83e39da6f5e7b3567fa112d3f022d4
Author: Andy Grove <an...@gmail.com>
AuthorDate: Thu Aug 25 11:24:00 2022 -0600

    Use `ExprRewriter` in `pre_cast_lit_in_comparison` (#3260)
    
    * Use ExprRewriter in pre_cast_lit_in_comparison.rs
    
    * remove manual recursion and add a nested test case
---
 .../optimizer/src/pre_cast_lit_in_comparison.rs    | 172 +++++++++++++--------
 1 file changed, 107 insertions(+), 65 deletions(-)

diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
index 0c16f7921..68c738ca8 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -20,6 +20,7 @@
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
 use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};
 
@@ -74,79 +75,92 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
         .collect::<Result<Vec<_>>>()?;
 
     let schema = plan.schema();
+
+    let mut expr_rewriter = PreCastLitExprRewriter {
+        schema: schema.clone(),
+    };
+
     let new_exprs = plan
         .expressions()
         .into_iter()
-        .map(|expr| visit_expr(expr, schema))
+        .map(|expr| expr.rewrite(&mut expr_rewriter))
         .collect::<Result<Vec<_>>>()?;
 
     from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
 }
 
-// Visit all type of expr, if the current has child expr, the child expr needed to visit first.
-fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
-    // traverse the expr by dfs
-    match &expr {
-        Expr::BinaryExpr { left, op, right } => {
-            // dfs visit the left and right expr
-            let left = visit_expr(*left.clone(), schema)?;
-            let right = visit_expr(*right.clone(), schema)?;
-            let left_type = left.get_type(schema);
-            let right_type = right.get_type(schema);
-            // can't get the data type, just return the expr
-            if left_type.is_err() || right_type.is_err() {
-                return Ok(expr.clone());
-            }
-            let left_type = left_type.unwrap();
-            let right_type = right_type.unwrap();
-            if !left_type.eq(&right_type)
-                && is_support_data_type(&left_type)
-                && is_support_data_type(&right_type)
-                && is_comparison_op(op)
-            {
-                match (&left, &right) {
-                    (Expr::Literal(_), Expr::Literal(_)) => {
-                        // do nothing
-                    }
-                    (Expr::Literal(left_lit_value), _)
-                        if can_integer_literal_cast_to_type(
-                            left_lit_value,
-                            &right_type,
-                        )? =>
-                    {
-                        // cast the left literal to the right type
-                        return Ok(binary_expr(
-                            cast_to_other_scalar_expr(left_lit_value, &right_type)?,
-                            *op,
-                            right,
-                        ));
-                    }
-                    (_, Expr::Literal(right_lit_value))
-                        if can_integer_literal_cast_to_type(
-                            right_lit_value,
-                            &left_type,
-                        )
-                        .unwrap() =>
-                    {
-                        // cast the right literal to the left type
-                        return Ok(binary_expr(
-                            left,
-                            *op,
-                            cast_to_other_scalar_expr(right_lit_value, &left_type)?,
-                        ));
-                    }
-                    (_, _) => {
-                        // do nothing
-                    }
-                };
+struct PreCastLitExprRewriter {
+    schema: DFSchemaRef,
+}
+
+impl ExprRewriter for PreCastLitExprRewriter {
+    fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
+        Ok(RewriteRecursion::Continue)
+    }
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        // child expressions are traversed by the `ExprRewriter` framework; only the current node is handled here
+        match &expr {
+            Expr::BinaryExpr { left, op, right } => {
+                let left = left.as_ref().clone();
+                let right = right.as_ref().clone();
+                let left_type = left.get_type(&self.schema);
+                let right_type = right.get_type(&self.schema);
+                // can't get the data type, just return the expr
+                if left_type.is_err() || right_type.is_err() {
+                    return Ok(expr.clone());
+                }
+                let left_type = left_type?;
+                let right_type = right_type?;
+                if !left_type.eq(&right_type)
+                    && is_support_data_type(&left_type)
+                    && is_support_data_type(&right_type)
+                    && is_comparison_op(op)
+                {
+                    match (&left, &right) {
+                        (Expr::Literal(_), Expr::Literal(_)) => {
+                            // do nothing
+                        }
+                        (Expr::Literal(left_lit_value), _)
+                            if can_integer_literal_cast_to_type(
+                                left_lit_value,
+                                &right_type,
+                            )? =>
+                        {
+                            // cast the left literal to the right type
+                            return Ok(binary_expr(
+                                cast_to_other_scalar_expr(left_lit_value, &right_type)?,
+                                *op,
+                                right,
+                            ));
+                        }
+                        (_, Expr::Literal(right_lit_value))
+                            if can_integer_literal_cast_to_type(
+                                right_lit_value,
+                                &left_type,
+                            )
+                            .unwrap() =>
+                        {
+                            // cast the right literal to the left type
+                            return Ok(binary_expr(
+                                left,
+                                *op,
+                                cast_to_other_scalar_expr(right_lit_value, &left_type)?,
+                            ));
+                        }
+                        (_, _) => {
+                            // do nothing
+                        }
+                    };
+                }
+                // return the new binary op
+                Ok(binary_expr(left, *op, right))
             }
-            // return the new binary op
-            Ok(binary_expr(left, *op, right))
+            // TODO: optimize in list
+            // Expr::InList { .. } => {}
+            // TODO: handle other expr type and dfs visit them
+            _ => Ok(expr),
         }
-        // TODO: optimize in list
-        // Expr::InList { .. } => {}
-        // TODO: handle other expr type and dfs visit them
-        _ => Ok(expr),
     }
 }
 
@@ -245,9 +259,10 @@ fn can_integer_literal_cast_to_type(
 
 #[cfg(test)]
 mod tests {
-    use crate::pre_cast_lit_in_comparison::visit_expr;
+    use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter;
     use arrow::datatypes::DataType;
     use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
+    use datafusion_expr::expr_rewriter::ExprRewritable;
     use datafusion_expr::{col, lit, Expr};
     use std::collections::HashMap;
     use std::sync::Arc;
@@ -292,8 +307,35 @@ mod tests {
         assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
     }
 
+    #[test]
+    fn aliased() {
+        let schema = expr_test_schema();
+        // c1 < INT64(16) -> c1 < cast(INT32(16))
+        // 16 lies between MIN(int32) and MAX(int32), so the literal can be cast to INT32(16)
+        let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x");
+        let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x");
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+    }
+
+    #[test]
+    fn nested() {
+        let schema = expr_test_schema();
+        // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
+        // both 16 and 32 lie between MIN(int32) and MAX(int32), so each literal can be cast to int32
+        let expr_lt = col("c1")
+            .lt(lit(ScalarValue::Int64(Some(16))))
+            .or(col("c1").gt(lit(ScalarValue::Int64(Some(32)))));
+        let expected = col("c1")
+            .lt(lit(ScalarValue::Int32(Some(16))))
+            .or(col("c1").gt(lit(ScalarValue::Int32(Some(32)))));
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+    }
+
     fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
-        visit_expr(expr, schema).unwrap()
+        let mut expr_rewriter = PreCastLitExprRewriter {
+            schema: schema.clone(),
+        };
+        expr.rewrite(&mut expr_rewriter).unwrap()
     }
 
     fn expr_test_schema() -> DFSchemaRef {