You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by xu...@apache.org on 2022/10/12 15:18:25 UTC

[arrow-datafusion] branch master updated: Consolidate and better tests for expression re-rewriting / aliasing (#3727)

This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e27e86baa Consolidate and better tests for expression re-rewriting / aliasing (#3727)
e27e86baa is described below

commit e27e86baa5a54dd4643055ee3b1521865d0977cd
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Oct 12 11:18:17 2022 -0400

    Consolidate and better tests for expression re-rewriting / aliasing (#3727)
---
 datafusion/expr/src/expr.rs                        |   4 +-
 datafusion/optimizer/src/type_coercion.rs          |  33 ++----
 .../optimizer/src/unwrap_cast_in_comparison.rs     |  40 +------
 datafusion/optimizer/src/utils.rs                  | 122 ++++++++++++++++++++-
 4 files changed, 136 insertions(+), 63 deletions(-)

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c131682a8..16c3da078 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -482,8 +482,8 @@ impl Expr {
     }
 
     /// Return `self AS name` alias expression
-    pub fn alias(self, name: &str) -> Expr {
-        Expr::Alias(Box::new(self), name.to_owned())
+    pub fn alias(self, name: impl Into<String>) -> Expr {
+        Expr::Alias(Box::new(self), name.into())
     }
 
     /// Return `self IN <list>` if `negated` is false, otherwise
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index be4d70265..ad9314406 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -17,11 +17,12 @@
 
 //! Optimizer rule for type validation and coercion
 
+use crate::utils::rewrite_preserving_name;
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
 use datafusion_expr::expr::Case;
-use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
+use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
 use datafusion_expr::logical_plan::Subquery;
 use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
 use datafusion_expr::type_coercion::functions::data_types;
@@ -91,30 +92,13 @@ fn optimize_internal(
         schema: Arc::new(schema),
     };
 
-    let original_expr_names: Vec<Option<String>> = plan
-        .expressions()
-        .iter()
-        .map(|expr| expr.name().ok())
-        .collect();
-
     let new_expr = plan
         .expressions()
         .into_iter()
-        .zip(original_expr_names)
-        .map(|(expr, original_name)| {
-            let expr = expr.rewrite(&mut expr_rewrite)?;
-
+        .map(|expr| {
             // ensure aggregate names don't change:
             // https://github.com/apache/arrow-datafusion/issues/3555
-            if matches!(expr, Expr::AggregateFunction { .. }) {
-                if let Some((alias, name)) = original_name.zip(expr.name().ok()) {
-                    if alias != name {
-                        return Ok(expr.alias(&alias));
-                    }
-                }
-            }
-
-            Ok(expr)
+            rewrite_preserving_name(expr, &mut expr_rewrite)
         })
         .collect::<Result<Vec<_>>>()?;
 
@@ -815,7 +799,8 @@ mod test {
         let mut config = OptimizerConfig::default();
         let plan = rule.optimize(&plan, &mut config)?;
         assert_eq!(
-            "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n  EmptyRelation",
+            "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
+             \n  EmptyRelation",
             &format!("{:?}", plan)
         );
         // a in (1,4,8), a is decimal
@@ -833,7 +818,8 @@ mod test {
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
         let plan = rule.optimize(&plan, &mut config)?;
         assert_eq!(
-            "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n  EmptyRelation",
+            "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
+             \n  EmptyRelation",
             &format!("{:?}", plan)
         );
         Ok(())
@@ -931,7 +917,8 @@ mod test {
         let mut config = OptimizerConfig::default();
         let plan = rule.optimize(&plan, &mut config).unwrap();
         assert_eq!(
-            "Projection: a LIKE CAST(NULL AS Utf8)\n  EmptyRelation",
+            "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \
+             \n  EmptyRelation",
             &format!("{:?}", plan)
         );
 
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 8392e28a9..8d04357ad 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -18,12 +18,13 @@
 //! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type
 //! of expr can be added if needed.
 //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr.
+use crate::utils::rewrite_preserving_name;
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::{
     DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
 };
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
+use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
     binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
@@ -97,47 +98,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
     let new_exprs = plan
         .expressions()
         .into_iter()
-        .map(|expr| {
-            let original_name = name_for_alias(&expr)?;
-            let expr = expr.rewrite(&mut expr_rewriter)?;
-            add_alias_if_changed(&original_name, expr)
-        })
+        .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter))
         .collect::<Result<Vec<_>>>()?;
 
     from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
 }
 
-fn name_for_alias(expr: &Expr) -> Result<String> {
-    match expr {
-        Expr::Sort { expr, .. } => name_for_alias(expr),
-        expr => expr.name(),
-    }
-}
-
-fn add_alias_if_changed(original_name: &str, expr: Expr) -> Result<Expr> {
-    let new_name = name_for_alias(&expr)?;
-
-    if new_name == original_name {
-        return Ok(expr);
-    }
-
-    Ok(match expr {
-        Expr::Sort {
-            expr,
-            asc,
-            nulls_first,
-        } => {
-            let expr = add_alias_if_changed(original_name, *expr)?;
-            Expr::Sort {
-                expr: Box::new(expr),
-                asc,
-                nulls_first,
-            }
-        }
-        expr => expr.alias(original_name),
-    })
-}
-
 struct UnwrapCastExprRewriter {
     schema: DFSchemaRef,
 }
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index d962dd7b4..a1174276d 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -20,6 +20,7 @@
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::Result;
 use datafusion_common::{plan_err, Column, DFSchemaRef};
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
 use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
 use datafusion_expr::{
     and, col, combine_filters,
@@ -315,13 +316,63 @@ pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
         .collect()
 }
 
+/// Rewrites `expr` using `rewriter`, ensuring that the output has the
+/// same name as `expr` prior to rewrite, adding an alias if necessary.
+///
+/// This is important when optimizing plans to ensure the output
+/// schema of plan nodes doesn't change after optimization
+pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
+where
+    R: ExprRewriter<Expr>,
+{
+    let original_name = name_for_alias(&expr)?;
+    let expr = expr.rewrite(rewriter)?;
+    add_alias_if_changed(original_name, expr)
+}
+
+/// Return the name to use for the specific Expr, recursing into
+/// `Expr::Sort` as appropriate
+fn name_for_alias(expr: &Expr) -> Result<String> {
+    match expr {
+        Expr::Sort { expr, .. } => name_for_alias(expr),
+        expr => expr.name(),
+    }
+}
+
+/// Ensure `expr` has the same name as `original_name` by adding an
+/// alias if necessary.
+fn add_alias_if_changed(original_name: String, expr: Expr) -> Result<Expr> {
+    let new_name = name_for_alias(&expr)?;
+
+    if new_name == original_name {
+        return Ok(expr);
+    }
+
+    Ok(match expr {
+        Expr::Sort {
+            expr,
+            asc,
+            nulls_first,
+        } => {
+            let expr = add_alias_if_changed(original_name, *expr)?;
+            Expr::Sort {
+                expr: Box::new(expr),
+                asc,
+                nulls_first,
+            }
+        }
+        expr => expr.alias(original_name),
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use arrow::datatypes::DataType;
     use datafusion_common::Column;
-    use datafusion_expr::{col, utils::expr_to_columns};
+    use datafusion_expr::{col, lit, utils::expr_to_columns};
     use std::collections::HashSet;
+    use std::ops::Add;
 
     #[test]
     fn test_collect_expr() -> Result<()> {
@@ -344,4 +395,73 @@ mod tests {
         assert!(accum.contains(&Column::from_name("a")));
         Ok(())
     }
+
+    #[test]
+    fn test_rewrite_preserving_name() {
+        test_rewrite(col("a"), col("a"));
+
+        test_rewrite(col("a"), col("b"));
+
+        // cast data types
+        test_rewrite(
+            col("a"),
+            Expr::Cast {
+                expr: Box::new(col("a")),
+                data_type: DataType::Int32,
+            },
+        );
+
+        // change literal type from i32 to i64
+        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));
+
+        // SortExpr a+1 ==> b + 2
+        test_rewrite(
+            Expr::Sort {
+                expr: Box::new(col("a").add(lit(1i32))),
+                asc: true,
+                nulls_first: false,
+            },
+            Expr::Sort {
+                expr: Box::new(col("b").add(lit(2i64))),
+                asc: true,
+                nulls_first: false,
+            },
+        );
+    }
+
+    /// rewrites `expr_from` to `rewrite_to` using
+    /// `rewrite_preserving_name`, verifying the output name is unchanged
+    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
+        struct TestRewriter {
+            rewrite_to: Expr,
+        }
+
+        impl ExprRewriter for TestRewriter {
+            fn mutate(&mut self, _: Expr) -> Result<Expr> {
+                Ok(self.rewrite_to.clone())
+            }
+        }
+
+        let mut rewriter = TestRewriter {
+            rewrite_to: rewrite_to.clone(),
+        };
+        let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();
+
+        let original_name = match &expr_from {
+            Expr::Sort { expr, .. } => expr.name(),
+            expr => expr.name(),
+        }
+        .unwrap();
+
+        let new_name = match &expr {
+            Expr::Sort { expr, .. } => expr.name(),
+            expr => expr.name(),
+        }
+        .unwrap();
+
+        assert_eq!(
+            original_name, new_name,
+            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
+        )
+    }
 }