You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/10 21:04:07 UTC

(arrow-datafusion) branch main updated: Simplify ProjectionPushdown and make it more general (#8109)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e305bcf197 Simplify ProjectionPushdown and make it more general (#8109)
e305bcf197 is described below

commit e305bcf197509dfb5c40d392cafad28d79effe08
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Nov 10 16:04:00 2023 -0500

    Simplify ProjectionPushdown and make it more general (#8109)
    
    * Simplify expression rewrite in ProjectionPushdown, make more general
    
    * Do not use partial rewrites
    
    * Apply suggestions from code review
    
    Co-authored-by: Berkay Şahin <12...@users.noreply.github.com>
    Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
    
    * cargo fmt, update comments
    
    ---------
    
    Co-authored-by: Berkay Şahin <12...@users.noreply.github.com>
    Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
 datafusion/common/src/tree_node.rs                 |  13 ++
 .../src/physical_optimizer/projection_pushdown.rs  | 162 +++++++--------------
 2 files changed, 62 insertions(+), 113 deletions(-)

diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs
index d0ef507294..5da9636ffe 100644
--- a/datafusion/common/src/tree_node.rs
+++ b/datafusion/common/src/tree_node.rs
@@ -149,6 +149,19 @@ pub trait TreeNode: Sized {
         Ok(new_node)
     }
 
+    /// Convenience utility for writing optimizer rules: recursively apply the given `op` first to all of its
+    /// children and then to itself (postorder traversal) using a mutable function, `F`.
+    /// When the `op` does not apply to a given node, it is left unchanged.
+    fn transform_up_mut<F>(self, op: &mut F) -> Result<Self>
+    where
+        F: FnMut(Self) -> Result<Transformed<Self>>,
+    {
+        let after_op_children = self.map_children(|node| node.transform_up_mut(op))?;
+
+        let new_node = op(after_op_children)?.into();
+        Ok(new_node)
+    }
+
     /// Transform the tree node using the given [TreeNodeRewriter]
     /// It performs a depth first walk of a node and its children.
     ///
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 1849595561..8e50492ae5 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -43,12 +43,9 @@ use arrow_schema::SchemaRef;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::JoinSide;
-use datafusion_physical_expr::expressions::{
-    BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
-};
+use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::{
     Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
-    ScalarFunctionExpr,
 };
 use datafusion_physical_plan::union::UnionExec;
 
@@ -791,119 +788,58 @@ fn update_expr(
     projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
     sync_with_child: bool,
 ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
-    let expr_any = expr.as_any();
-    if let Some(column) = expr_any.downcast_ref::<Column>() {
-        if sync_with_child {
-            // Update the index of `column`:
-            Ok(Some(projected_exprs[column.index()].0.clone()))
-        } else {
-            // Determine how to update `column` to accommodate `projected_exprs`:
-            Ok(projected_exprs.iter().enumerate().find_map(
-                |(index, (projected_expr, alias))| {
-                    projected_expr.as_any().downcast_ref::<Column>().and_then(
-                        |projected_column| {
-                            column
-                                .name()
-                                .eq(projected_column.name())
-                                .then(|| Arc::new(Column::new(alias, index)) as _)
-                        },
-                    )
-                },
-            ))
-        }
-    } else if let Some(binary) = expr_any.downcast_ref::<BinaryExpr>() {
-        match (
-            update_expr(binary.left(), projected_exprs, sync_with_child)?,
-            update_expr(binary.right(), projected_exprs, sync_with_child)?,
-        ) {
-            (Some(left), Some(right)) => {
-                Ok(Some(Arc::new(BinaryExpr::new(left, *binary.op(), right))))
-            }
-            _ => Ok(None),
-        }
-    } else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
-        update_expr(cast.expr(), projected_exprs, sync_with_child).map(|maybe_expr| {
-            maybe_expr.map(|expr| {
-                Arc::new(CastExpr::new(
-                    expr,
-                    cast.cast_type().clone(),
-                    Some(cast.cast_options().clone()),
-                )) as _
-            })
-        })
-    } else if expr_any.is::<Literal>() {
-        Ok(Some(expr.clone()))
-    } else if let Some(negative) = expr_any.downcast_ref::<NegativeExpr>() {
-        update_expr(negative.arg(), projected_exprs, sync_with_child).map(|maybe_expr| {
-            maybe_expr.map(|expr| Arc::new(NegativeExpr::new(expr)) as _)
-        })
-    } else if let Some(scalar_func) = expr_any.downcast_ref::<ScalarFunctionExpr>() {
-        scalar_func
-            .args()
-            .iter()
-            .map(|expr| update_expr(expr, projected_exprs, sync_with_child))
-            .collect::<Result<Option<Vec<_>>>>()
-            .map(|maybe_args| {
-                maybe_args.map(|new_args| {
-                    Arc::new(ScalarFunctionExpr::new(
-                        scalar_func.name(),
-                        scalar_func.fun().clone(),
-                        new_args,
-                        scalar_func.return_type(),
-                        scalar_func.monotonicity().clone(),
-                    )) as _
-                })
-            })
-    } else if let Some(case) = expr_any.downcast_ref::<CaseExpr>() {
-        update_case_expr(case, projected_exprs, sync_with_child)
-    } else {
-        Ok(None)
+    #[derive(Debug, PartialEq)]
+    enum RewriteState {
+        /// The expression is unchanged.
+        Unchanged,
+        /// Some part of the expression has been rewritten
+        RewrittenValid,
+        /// Some part of the expression has been rewritten, but some column
+        /// references could not be.
+        RewrittenInvalid,
     }
-}
 
-/// Updates the indices `case` refers to according to `projected_exprs`.
-fn update_case_expr(
-    case: &CaseExpr,
-    projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
-    sync_with_child: bool,
-) -> Result<Option<Arc<dyn PhysicalExpr>>> {
-    let new_case = case
-        .expr()
-        .map(|expr| update_expr(expr, projected_exprs, sync_with_child))
-        .transpose()?
-        .flatten();
-
-    let new_else = case
-        .else_expr()
-        .map(|expr| update_expr(expr, projected_exprs, sync_with_child))
-        .transpose()?
-        .flatten();
-
-    let new_when_then = case
-        .when_then_expr()
-        .iter()
-        .map(|(when, then)| {
-            Ok((
-                update_expr(when, projected_exprs, sync_with_child)?,
-                update_expr(then, projected_exprs, sync_with_child)?,
-            ))
-        })
-        .collect::<Result<Vec<_>>>()?
-        .into_iter()
-        .filter_map(|(maybe_when, maybe_then)| match (maybe_when, maybe_then) {
-            (Some(when), Some(then)) => Some((when, then)),
-            _ => None,
-        })
-        .collect::<Vec<_>>();
+    let mut state = RewriteState::Unchanged;
 
-    if new_when_then.len() != case.when_then_expr().len()
-        || case.expr().is_some() && new_case.is_none()
-        || case.else_expr().is_some() && new_else.is_none()
-    {
-        return Ok(None);
-    }
+    let new_expr = expr
+        .clone()
+        .transform_up_mut(&mut |expr: Arc<dyn PhysicalExpr>| {
+            if state == RewriteState::RewrittenInvalid {
+                return Ok(Transformed::No(expr));
+            }
+
+            let Some(column) = expr.as_any().downcast_ref::<Column>() else {
+                return Ok(Transformed::No(expr));
+            };
+            if sync_with_child {
+                state = RewriteState::RewrittenValid;
+                // Update the index of `column`:
+                Ok(Transformed::Yes(projected_exprs[column.index()].0.clone()))
+            } else {
+                // default to invalid, in case we can't find the relevant column
+                state = RewriteState::RewrittenInvalid;
+                // Determine how to update `column` to accommodate `projected_exprs`
+                projected_exprs
+                    .iter()
+                    .enumerate()
+                    .find_map(|(index, (projected_expr, alias))| {
+                        projected_expr.as_any().downcast_ref::<Column>().and_then(
+                            |projected_column| {
+                                column.name().eq(projected_column.name()).then(|| {
+                                    state = RewriteState::RewrittenValid;
+                                    Arc::new(Column::new(alias, index)) as _
+                                })
+                            },
+                        )
+                    })
+                    .map_or_else(
+                        || Ok(Transformed::No(expr)),
+                        |c| Ok(Transformed::Yes(c)),
+                    )
+            }
+        });
 
-    CaseExpr::try_new(new_case, new_when_then, new_else).map(|e| Some(Arc::new(e) as _))
+    new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e))
 }
 
 /// Creates a new [`ProjectionExec`] instance with the given child plan and