You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by mn...@apache.org on 2023/05/24 10:31:35 UTC

[arrow-datafusion] branch main updated: refactor: split `CommonSubexprEliminate::try_optimize` (#6348)

This is an automated email from the ASF dual-hosted git repository.

mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a2f07380b refactor: split `CommonSubexprEliminate::try_optimize` (#6348)
2a2f07380b is described below

commit 2a2f07380b8bfe16b9fb4b0c0b3bc6cd90b0a1fc
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Wed May 24 12:31:28 2023 +0200

    refactor: split `CommonSubexprEliminate::try_optimize` (#6348)
    
    Having a single method with all optimization algorithms makes rustc+LLVM
    reserve a stack frame as large as the largest branch needs when compiled
    under the `dev` profile (e.g. for tests). So even for a deeply nested
    projection (e.g. in `tpcds_logical_q64`), every recursion pays the stack
    cost of the more complex aggregation optimization.
    
    Splitting the method into sub-methods seems to fix this.
    
    Fixes #6277.
---
 .../optimizer/src/common_subexpr_eliminate.rs      | 460 +++++++++++----------
 1 file changed, 246 insertions(+), 214 deletions(-)

diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 6989ca5352..0f63ecc2cc 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -63,7 +63,7 @@ impl CommonSubexprEliminate {
         &self,
         exprs_list: &[&[Expr]],
         arrays_list: &[&[Vec<(usize, String)>]],
-        expr_set: &mut ExprSet,
+        expr_set: &ExprSet,
         affected_id: &mut BTreeSet<Identifier>,
     ) -> Result<Vec<Vec<Expr>>> {
         exprs_list
@@ -87,7 +87,7 @@ impl CommonSubexprEliminate {
         exprs_list: &[&[Expr]],
         arrays_list: &[&[Vec<(usize, String)>]],
         input: &LogicalPlan,
-        expr_set: &mut ExprSet,
+        expr_set: &ExprSet,
         config: &dyn OptimizerConfig,
     ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
         let mut affected_id = BTreeSet::<Identifier>::new();
@@ -104,222 +104,253 @@ impl CommonSubexprEliminate {
 
         Ok((rewrite_exprs, new_input))
     }
-}
 
-impl OptimizerRule for CommonSubexprEliminate {
-    fn try_optimize(
+    fn try_optimize_projection(
         &self,
-        plan: &LogicalPlan,
+        projection: &Projection,
         config: &dyn OptimizerConfig,
-    ) -> Result<Option<LogicalPlan>> {
+    ) -> Result<LogicalPlan> {
+        let Projection {
+            expr,
+            input,
+            schema,
+            ..
+        } = projection;
+        let input_schema = Arc::clone(input.schema());
         let mut expr_set = ExprSet::new();
+        let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
 
-        let original_schema = plan.schema().clone();
-        let optimized_plan = match plan {
-            LogicalPlan::Projection(Projection {
-                expr,
-                input,
-                schema,
-                ..
-            }) => {
-                let input_schema = Arc::clone(input.schema());
-                let arrays =
-                    to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
-
-                let (mut new_expr, new_input) =
-                    self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?;
-
-                Some(LogicalPlan::Projection(Projection::try_new_with_schema(
-                    pop_expr(&mut new_expr)?,
-                    Arc::new(new_input),
-                    schema.clone(),
-                )?))
-            }
-            LogicalPlan::Filter(filter) => {
-                let predicate = &filter.predicate;
-                let input_schema = Arc::clone(filter.input.schema());
-                let mut id_array = vec![];
-                expr_to_identifier(
-                    predicate,
-                    &mut expr_set,
-                    &mut id_array,
-                    input_schema,
-                    ExprMask::Normal,
-                )?;
-
-                let (mut new_expr, new_input) = self.rewrite_expr(
-                    &[&[predicate.clone()]],
-                    &[&[id_array]],
-                    &filter.input,
-                    &mut expr_set,
-                    config,
-                )?;
-
-                if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
-                    Some(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        Arc::new(new_input),
-                    )?))
-                } else {
-                    return Err(DataFusionError::Internal(
-                        "Failed to pop predicate expr".to_string(),
-                    ));
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?;
+
+        Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
+            pop_expr(&mut new_expr)?,
+            Arc::new(new_input),
+            schema.clone(),
+        )?))
+    }
+
+    fn try_optimize_filter(
+        &self,
+        filter: &Filter,
+        config: &dyn OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        let mut expr_set = ExprSet::new();
+        let predicate = &filter.predicate;
+        let input_schema = Arc::clone(filter.input.schema());
+        let mut id_array = vec![];
+        expr_to_identifier(
+            predicate,
+            &mut expr_set,
+            &mut id_array,
+            input_schema,
+            ExprMask::Normal,
+        )?;
+
+        let (mut new_expr, new_input) = self.rewrite_expr(
+            &[&[predicate.clone()]],
+            &[&[id_array]],
+            &filter.input,
+            &expr_set,
+            config,
+        )?;
+
+        if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
+            Ok(LogicalPlan::Filter(Filter::try_new(
+                predicate,
+                Arc::new(new_input),
+            )?))
+        } else {
+            Err(DataFusionError::Internal(
+                "Failed to pop predicate expr".to_string(),
+            ))
+        }
+    }
+
+    fn try_optimize_window(
+        &self,
+        window: &Window,
+        config: &dyn OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        let Window {
+            input,
+            window_expr,
+            schema,
+        } = window;
+        let mut expr_set = ExprSet::new();
+
+        let input_schema = Arc::clone(input.schema());
+        let arrays =
+            to_arrays(window_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[window_expr], &[&arrays], input, &expr_set, config)?;
+
+        Ok(LogicalPlan::Window(Window {
+            input: Arc::new(new_input),
+            window_expr: pop_expr(&mut new_expr)?,
+            schema: schema.clone(),
+        }))
+    }
+
+    fn try_optimize_aggregate(
+        &self,
+        aggregate: &Aggregate,
+        config: &dyn OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        let Aggregate {
+            group_expr,
+            aggr_expr,
+            input,
+            schema,
+            ..
+        } = aggregate;
+        let mut expr_set = ExprSet::new();
+
+        // rewrite inputs
+        let input_schema = Arc::clone(input.schema());
+        let group_arrays = to_arrays(
+            group_expr,
+            Arc::clone(&input_schema),
+            &mut expr_set,
+            ExprMask::Normal,
+        )?;
+        let aggr_arrays =
+            to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+
+        let (mut new_expr, new_input) = self.rewrite_expr(
+            &[group_expr, aggr_expr],
+            &[&group_arrays, &aggr_arrays],
+            input,
+            &expr_set,
+            config,
+        )?;
+        // note the reversed pop order.
+        let new_aggr_expr = pop_expr(&mut new_expr)?;
+        let new_group_expr = pop_expr(&mut new_expr)?;
+
+        // create potential projection on top
+        let mut expr_set = ExprSet::new();
+        let new_input_schema = Arc::clone(new_input.schema());
+        let aggr_arrays = to_arrays(
+            &new_aggr_expr,
+            new_input_schema.clone(),
+            &mut expr_set,
+            ExprMask::NormalAndAggregates,
+        )?;
+        let mut affected_id = BTreeSet::<Identifier>::new();
+        let mut rewritten = self.rewrite_exprs_list(
+            &[&new_aggr_expr],
+            &[&aggr_arrays],
+            &expr_set,
+            &mut affected_id,
+        )?;
+        let rewritten = pop_expr(&mut rewritten)?;
+
+        if affected_id.is_empty() {
+            Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+                Arc::new(new_input),
+                new_group_expr,
+                new_aggr_expr,
+                schema.clone(),
+            )?))
+        } else {
+            let mut agg_exprs = vec![];
+
+            for id in affected_id {
+                match expr_set.get(&id) {
+                    Some((expr, _, _)) => {
+                        // todo: check `nullable`
+                        agg_exprs.push(expr.clone().alias(&id));
+                    }
+                    _ => {
+                        return Err(DataFusionError::Internal(
+                            "expr_set invalid state".to_string(),
+                        ));
+                    }
                 }
             }
-            LogicalPlan::Window(Window {
-                input,
-                window_expr,
-                schema,
-            }) => {
-                let input_schema = Arc::clone(input.schema());
-                let arrays = to_arrays(
-                    window_expr,
-                    input_schema,
-                    &mut expr_set,
-                    ExprMask::Normal,
-                )?;
-
-                let (mut new_expr, new_input) = self.rewrite_expr(
-                    &[window_expr],
-                    &[&arrays],
-                    input,
-                    &mut expr_set,
-                    config,
-                )?;
-
-                Some(LogicalPlan::Window(Window {
-                    input: Arc::new(new_input),
-                    window_expr: pop_expr(&mut new_expr)?,
-                    schema: schema.clone(),
-                }))
+
+            let mut proj_exprs = vec![];
+            for expr in &new_group_expr {
+                let out_col: Column =
+                    expr.to_field(&new_input_schema)?.qualified_column();
+                proj_exprs.push(Expr::Column(out_col));
             }
-            LogicalPlan::Aggregate(Aggregate {
-                group_expr,
-                aggr_expr,
-                input,
-                schema,
-                ..
-            }) => {
-                // rewrite inputs
-                let input_schema = Arc::clone(input.schema());
-                let group_arrays = to_arrays(
-                    group_expr,
-                    Arc::clone(&input_schema),
-                    &mut expr_set,
-                    ExprMask::Normal,
-                )?;
-                let aggr_arrays =
-                    to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
-
-                let (mut new_expr, new_input) = self.rewrite_expr(
-                    &[group_expr, aggr_expr],
-                    &[&group_arrays, &aggr_arrays],
-                    input,
-                    &mut expr_set,
-                    config,
-                )?;
-                // note the reversed pop order.
-                let new_aggr_expr = pop_expr(&mut new_expr)?;
-                let new_group_expr = pop_expr(&mut new_expr)?;
-
-                // create potential projection on top
-                let mut expr_set = ExprSet::new();
-                let new_input_schema = Arc::clone(new_input.schema());
-                let aggr_arrays = to_arrays(
-                    &new_aggr_expr,
-                    new_input_schema.clone(),
-                    &mut expr_set,
-                    ExprMask::NormalAndAggregates,
-                )?;
-                let mut affected_id = BTreeSet::<Identifier>::new();
-                let mut rewritten = self.rewrite_exprs_list(
-                    &[&new_aggr_expr],
-                    &[&aggr_arrays],
-                    &mut expr_set,
-                    &mut affected_id,
-                )?;
-                let rewritten = pop_expr(&mut rewritten)?;
-
-                if affected_id.is_empty() {
-                    Some(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
-                        Arc::new(new_input),
-                        new_group_expr,
-                        new_aggr_expr,
-                        schema.clone(),
-                    )?))
-                } else {
-                    let mut agg_exprs = vec![];
-
-                    for id in affected_id {
-                        match expr_set.get(&id) {
-                            Some((expr, _, _)) => {
-                                // todo: check `nullable`
-                                agg_exprs.push(expr.clone().alias(&id));
-                            }
-                            _ => {
-                                return Err(DataFusionError::Internal(
-                                    "expr_set invalid state".to_string(),
-                                ));
-                            }
-                        }
+            for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) {
+                if expr_rewritten == expr_orig {
+                    if let Expr::Alias(expr, name) = expr_rewritten {
+                        agg_exprs.push(expr.alias(&name));
+                        proj_exprs.push(Expr::Column(Column::from_name(name)));
+                    } else {
+                        let id =
+                            ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten);
+                        let out_name =
+                            expr_rewritten.to_field(&new_input_schema)?.qualified_name();
+                        agg_exprs.push(expr_rewritten.alias(&id));
+                        proj_exprs
+                            .push(Expr::Column(Column::from_name(id)).alias(out_name));
                     }
+                } else {
+                    proj_exprs.push(expr_rewritten);
+                }
+            }
 
-                    let mut proj_exprs = vec![];
-                    for expr in &new_group_expr {
-                        let out_col: Column =
-                            expr.to_field(&new_input_schema)?.qualified_column();
-                        proj_exprs.push(Expr::Column(out_col));
-                    }
-                    for (expr_rewritten, expr_orig) in
-                        rewritten.into_iter().zip(new_aggr_expr)
-                    {
-                        if expr_rewritten == expr_orig {
-                            if let Expr::Alias(expr, name) = expr_rewritten {
-                                agg_exprs.push(expr.alias(&name));
-                                proj_exprs.push(Expr::Column(Column::from_name(name)));
-                            } else {
-                                let id = ExprIdentifierVisitor::<'static>::desc_expr(
-                                    &expr_rewritten,
-                                );
-                                let out_name = expr_rewritten
-                                    .to_field(&new_input_schema)?
-                                    .qualified_name();
-                                agg_exprs.push(expr_rewritten.alias(&id));
-                                proj_exprs.push(
-                                    Expr::Column(Column::from_name(id)).alias(out_name),
-                                );
-                            }
-                        } else {
-                            proj_exprs.push(expr_rewritten);
-                        }
-                    }
+            let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+                Arc::new(new_input),
+                new_group_expr,
+                agg_exprs,
+            )?);
 
-                    let agg = LogicalPlan::Aggregate(Aggregate::try_new(
-                        Arc::new(new_input),
-                        new_group_expr,
-                        agg_exprs,
-                    )?);
+            Ok(LogicalPlan::Projection(Projection::try_new(
+                proj_exprs,
+                Arc::new(agg),
+            )?))
+        }
+    }
 
-                    Some(LogicalPlan::Projection(Projection::try_new(
-                        proj_exprs,
-                        Arc::new(agg),
-                    )?))
-                }
+    fn try_optimize_sort(
+        &self,
+        sort: &Sort,
+        config: &dyn OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        let Sort { expr, input, fetch } = sort;
+        let mut expr_set = ExprSet::new();
+
+        let input_schema = Arc::clone(input.schema());
+        let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?;
+
+        Ok(LogicalPlan::Sort(Sort {
+            expr: pop_expr(&mut new_expr)?,
+            input: Arc::new(new_input),
+            fetch: *fetch,
+        }))
+    }
+}
+
+impl OptimizerRule for CommonSubexprEliminate {
+    fn try_optimize(
+        &self,
+        plan: &LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Option<LogicalPlan>> {
+        let optimized_plan = match plan {
+            LogicalPlan::Projection(projection) => {
+                Some(self.try_optimize_projection(projection, config)?)
             }
-            LogicalPlan::Sort(Sort { expr, input, fetch }) => {
-                let input_schema = Arc::clone(input.schema());
-                let arrays =
-                    to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
-
-                let (mut new_expr, new_input) =
-                    self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?;
-
-                Some(LogicalPlan::Sort(Sort {
-                    expr: pop_expr(&mut new_expr)?,
-                    input: Arc::new(new_input),
-                    fetch: *fetch,
-                }))
+            LogicalPlan::Filter(filter) => {
+                Some(self.try_optimize_filter(filter, config)?)
+            }
+            LogicalPlan::Window(window) => {
+                Some(self.try_optimize_window(window, config)?)
+            }
+            LogicalPlan::Aggregate(aggregate) => {
+                Some(self.try_optimize_aggregate(aggregate, config)?)
             }
+            LogicalPlan::Sort(sort) => Some(self.try_optimize_sort(sort, config)?),
             LogicalPlan::Join(_)
             | LogicalPlan::CrossJoin(_)
             | LogicalPlan::Repartition(_)
@@ -345,6 +376,7 @@ impl OptimizerRule for CommonSubexprEliminate {
             }
         };
 
+        let original_schema = plan.schema().clone();
         match optimized_plan {
             Some(optimized_plan) if optimized_plan.schema() != &original_schema => {
                 // add an additional projection if the output schema changed.
@@ -634,7 +666,7 @@ fn expr_to_identifier(
 /// the corresponding temporary column name. That column contains the
 /// evaluate result of replaced expression.
 struct CommonSubexprRewriter<'a> {
-    expr_set: &'a mut ExprSet,
+    expr_set: &'a ExprSet,
     id_array: &'a [(usize, Identifier)],
     /// Which identifier is replaced.
     affected_id: &'a mut BTreeSet<Identifier>,
@@ -717,7 +749,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
 fn replace_common_expr(
     expr: Expr,
     id_array: &[(usize, Identifier)],
-    expr_set: &mut ExprSet,
+    expr_set: &ExprSet,
     affected_id: &mut BTreeSet<Identifier>,
 ) -> Result<Expr> {
     expr.rewrite(&mut CommonSubexprRewriter {
@@ -783,14 +815,14 @@ mod test {
         )?;
 
         let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), 
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"), 
-            (4, ""), 
-            (3, "a + Int32(1)Int32(1)a"), 
-            (1, ""), 
-            (2, ""), 
-            (6, ""), 
-            (5, ""), 
+            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
+            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
+            (4, ""),
+            (3, "a + Int32(1)Int32(1)a"),
+            (1, ""),
+            (2, ""),
+            (6, ""),
+            (5, ""),
             (8, "")
         ]
         .into_iter()