You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2024/01/08 19:14:54 UTC

(arrow-datafusion) 01/01: clean

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch alamb/clean_rewrite
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit b832e43001f91eb6136e55da851b93a2756f4594
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Mon Jan 8 14:14:46 2024 -0500

    clean
---
 datafusion/optimizer/src/optimize_projections.rs | 71 +++++++++++-------------
 1 file changed, 33 insertions(+), 38 deletions(-)

diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs
index 1d4eda0bd2..319f738038 100644
--- a/datafusion/optimizer/src/optimize_projections.rs
+++ b/datafusion/optimizer/src/optimize_projections.rs
@@ -41,6 +41,8 @@ use datafusion_expr::{
 
 use hashbrown::HashMap;
 use itertools::{izip, Itertools};
+use datafusion_common::tree_node::{TreeNode, VisitRecursion};
+use datafusion_expr::utils::inspect_expr_pre;
 
 /// A rule for optimizing logical plans by removing unused columns/fields.
 ///
@@ -450,7 +452,8 @@ fn merge_consecutive_projections(proj: &Projection) -> Result<Option<Projection>
         .expr
         .iter()
         .map(|expr| rewrite_expr(expr, prev_projection))
-        .collect::<Result<Option<Vec<_>>>>()?;
+        .collect::<Option<Vec<_>>>();
+
     if let Some(new_exprs) = new_exprs {
         let new_exprs = new_exprs
             .into_iter()
@@ -532,46 +535,38 @@ macro_rules! rewrite_expr_with_check {
 /// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result.
 /// - `Ok(None)`: Signals that `expr` can not be rewritten.
 /// - `Err(error)`: An error occured during the function call.
-fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
-    let result = match expr {
-        Expr::Column(col) => {
+fn rewrite_expr(expr: &Expr, input: &Projection) -> Option<Expr> {
+    let mut have_all_references = true;
+    expr.apply(|e| {
+        if let Expr::Column(col) = e {
+            if input.schema.index_of_column(&col).is_none() {
+                have_all_references = false;
+                return Ok(VisitRecursion::Stop);
+            }
+        }
+        Ok(VisitRecursion::Continue)
+    })
+        // closure above never returns err, so unwrap here is ok
+        .unwrap();
+
+    if !have_all_references {
+        return None;
+    }
+
+    let expr = expr.map_children(|e| {
+        let new_expr = if let Expr::Column(col) = e {
             // Find index of column:
             let idx = input.schema.index_of_column(col)?;
             input.expr[idx].clone()
-        }
-        Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new(
-            Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))),
-            binary.op,
-            Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))),
-        )),
-        Expr::Alias(alias) => Expr::Alias(Alias::new(
-            trim_expr(rewrite_expr_with_check!(&alias.expr, input)),
-            alias.relation.clone(),
-            alias.name.clone(),
-        )),
-        Expr::Literal(_) => expr.clone(),
-        Expr::Cast(cast) => {
-            let new_expr = rewrite_expr_with_check!(&cast.expr, input);
-            Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone()))
-        }
-        Expr::ScalarFunction(scalar_fn) => {
-            // TODO: Support UDFs.
-            let ScalarFunctionDefinition::BuiltIn(fun) = scalar_fn.func_def else {
-                return Ok(None);
-            };
-            return Ok(scalar_fn
-                .args
-                .iter()
-                .map(|expr| rewrite_expr(expr, input))
-                .collect::<Result<Option<_>>>()?
-                .map(|new_args| {
-                    Expr::ScalarFunction(ScalarFunction::new(fun, new_args))
-                }));
-        }
-        // Unsupported type for consecutive projection merge analysis.
-        _ => return Ok(None),
-    };
-    Ok(Some(result))
+        } else {
+            e
+        };
+        Ok(new_expr)
+    })
+        // previously checked that all columns are in the schema, so unwrap here is ok
+        .unwrap();
+
+    Some(expr)
 }
 
 /// Retrieves a set of outer-referenced columns by the given expression, `expr`.