You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by mn...@apache.org on 2023/05/24 10:31:35 UTC
[arrow-datafusion] branch main updated: refactor: split `CommonSubexprEliminate::try_optimize` (#6348)
This is an automated email from the ASF dual-hosted git repository.
mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2a2f07380b refactor: split `CommonSubexprEliminate::try_optimize` (#6348)
2a2f07380b is described below
commit 2a2f07380b8bfe16b9fb4b0c0b3bc6cd90b0a1fc
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Wed May 24 12:31:28 2023 +0200
refactor: split `CommonSubexprEliminate::try_optimize` (#6348)
Having a single method with all optimization algorithms makes rustc+LLVM
reserve the largest stack frame needed by any of its branches when compiled
under the `dev` profile (e.g. for tests). So even if the plan is just a deeply
nested projection (e.g. for `tpcds_logical_q64`), we pay the stack cost of the
more complex aggregation optimization with every recursion.
Splitting the method into sub-methods seems to fix this.
Fixes #6277.
---
.../optimizer/src/common_subexpr_eliminate.rs | 460 +++++++++++----------
1 file changed, 246 insertions(+), 214 deletions(-)
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 6989ca5352..0f63ecc2cc 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -63,7 +63,7 @@ impl CommonSubexprEliminate {
&self,
exprs_list: &[&[Expr]],
arrays_list: &[&[Vec<(usize, String)>]],
- expr_set: &mut ExprSet,
+ expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Vec<Vec<Expr>>> {
exprs_list
@@ -87,7 +87,7 @@ impl CommonSubexprEliminate {
exprs_list: &[&[Expr]],
arrays_list: &[&[Vec<(usize, String)>]],
input: &LogicalPlan,
- expr_set: &mut ExprSet,
+ expr_set: &ExprSet,
config: &dyn OptimizerConfig,
) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
let mut affected_id = BTreeSet::<Identifier>::new();
@@ -104,222 +104,253 @@ impl CommonSubexprEliminate {
Ok((rewrite_exprs, new_input))
}
-}
-impl OptimizerRule for CommonSubexprEliminate {
- fn try_optimize(
+ fn try_optimize_projection(
&self,
- plan: &LogicalPlan,
+ projection: &Projection,
config: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
+ ) -> Result<LogicalPlan> {
+ let Projection {
+ expr,
+ input,
+ schema,
+ ..
+ } = projection;
+ let input_schema = Arc::clone(input.schema());
let mut expr_set = ExprSet::new();
+ let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
- let original_schema = plan.schema().clone();
- let optimized_plan = match plan {
- LogicalPlan::Projection(Projection {
- expr,
- input,
- schema,
- ..
- }) => {
- let input_schema = Arc::clone(input.schema());
- let arrays =
- to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
-
- let (mut new_expr, new_input) =
- self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?;
-
- Some(LogicalPlan::Projection(Projection::try_new_with_schema(
- pop_expr(&mut new_expr)?,
- Arc::new(new_input),
- schema.clone(),
- )?))
- }
- LogicalPlan::Filter(filter) => {
- let predicate = &filter.predicate;
- let input_schema = Arc::clone(filter.input.schema());
- let mut id_array = vec![];
- expr_to_identifier(
- predicate,
- &mut expr_set,
- &mut id_array,
- input_schema,
- ExprMask::Normal,
- )?;
-
- let (mut new_expr, new_input) = self.rewrite_expr(
- &[&[predicate.clone()]],
- &[&[id_array]],
- &filter.input,
- &mut expr_set,
- config,
- )?;
-
- if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
- Some(LogicalPlan::Filter(Filter::try_new(
- predicate,
- Arc::new(new_input),
- )?))
- } else {
- return Err(DataFusionError::Internal(
- "Failed to pop predicate expr".to_string(),
- ));
+ let (mut new_expr, new_input) =
+ self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?;
+
+ Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
+ pop_expr(&mut new_expr)?,
+ Arc::new(new_input),
+ schema.clone(),
+ )?))
+ }
+
+ fn try_optimize_filter(
+ &self,
+ filter: &Filter,
+ config: &dyn OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ let mut expr_set = ExprSet::new();
+ let predicate = &filter.predicate;
+ let input_schema = Arc::clone(filter.input.schema());
+ let mut id_array = vec![];
+ expr_to_identifier(
+ predicate,
+ &mut expr_set,
+ &mut id_array,
+ input_schema,
+ ExprMask::Normal,
+ )?;
+
+ let (mut new_expr, new_input) = self.rewrite_expr(
+ &[&[predicate.clone()]],
+ &[&[id_array]],
+ &filter.input,
+ &expr_set,
+ config,
+ )?;
+
+ if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
+ Ok(LogicalPlan::Filter(Filter::try_new(
+ predicate,
+ Arc::new(new_input),
+ )?))
+ } else {
+ Err(DataFusionError::Internal(
+ "Failed to pop predicate expr".to_string(),
+ ))
+ }
+ }
+
+ fn try_optimize_window(
+ &self,
+ window: &Window,
+ config: &dyn OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ let Window {
+ input,
+ window_expr,
+ schema,
+ } = window;
+ let mut expr_set = ExprSet::new();
+
+ let input_schema = Arc::clone(input.schema());
+ let arrays =
+ to_arrays(window_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+
+ let (mut new_expr, new_input) =
+ self.rewrite_expr(&[window_expr], &[&arrays], input, &expr_set, config)?;
+
+ Ok(LogicalPlan::Window(Window {
+ input: Arc::new(new_input),
+ window_expr: pop_expr(&mut new_expr)?,
+ schema: schema.clone(),
+ }))
+ }
+
+ fn try_optimize_aggregate(
+ &self,
+ aggregate: &Aggregate,
+ config: &dyn OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ let Aggregate {
+ group_expr,
+ aggr_expr,
+ input,
+ schema,
+ ..
+ } = aggregate;
+ let mut expr_set = ExprSet::new();
+
+ // rewrite inputs
+ let input_schema = Arc::clone(input.schema());
+ let group_arrays = to_arrays(
+ group_expr,
+ Arc::clone(&input_schema),
+ &mut expr_set,
+ ExprMask::Normal,
+ )?;
+ let aggr_arrays =
+ to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+
+ let (mut new_expr, new_input) = self.rewrite_expr(
+ &[group_expr, aggr_expr],
+ &[&group_arrays, &aggr_arrays],
+ input,
+ &expr_set,
+ config,
+ )?;
+ // note the reversed pop order.
+ let new_aggr_expr = pop_expr(&mut new_expr)?;
+ let new_group_expr = pop_expr(&mut new_expr)?;
+
+ // create potential projection on top
+ let mut expr_set = ExprSet::new();
+ let new_input_schema = Arc::clone(new_input.schema());
+ let aggr_arrays = to_arrays(
+ &new_aggr_expr,
+ new_input_schema.clone(),
+ &mut expr_set,
+ ExprMask::NormalAndAggregates,
+ )?;
+ let mut affected_id = BTreeSet::<Identifier>::new();
+ let mut rewritten = self.rewrite_exprs_list(
+ &[&new_aggr_expr],
+ &[&aggr_arrays],
+ &expr_set,
+ &mut affected_id,
+ )?;
+ let rewritten = pop_expr(&mut rewritten)?;
+
+ if affected_id.is_empty() {
+ Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
+ Arc::new(new_input),
+ new_group_expr,
+ new_aggr_expr,
+ schema.clone(),
+ )?))
+ } else {
+ let mut agg_exprs = vec![];
+
+ for id in affected_id {
+ match expr_set.get(&id) {
+ Some((expr, _, _)) => {
+ // todo: check `nullable`
+ agg_exprs.push(expr.clone().alias(&id));
+ }
+ _ => {
+ return Err(DataFusionError::Internal(
+ "expr_set invalid state".to_string(),
+ ));
+ }
}
}
- LogicalPlan::Window(Window {
- input,
- window_expr,
- schema,
- }) => {
- let input_schema = Arc::clone(input.schema());
- let arrays = to_arrays(
- window_expr,
- input_schema,
- &mut expr_set,
- ExprMask::Normal,
- )?;
-
- let (mut new_expr, new_input) = self.rewrite_expr(
- &[window_expr],
- &[&arrays],
- input,
- &mut expr_set,
- config,
- )?;
-
- Some(LogicalPlan::Window(Window {
- input: Arc::new(new_input),
- window_expr: pop_expr(&mut new_expr)?,
- schema: schema.clone(),
- }))
+
+ let mut proj_exprs = vec![];
+ for expr in &new_group_expr {
+ let out_col: Column =
+ expr.to_field(&new_input_schema)?.qualified_column();
+ proj_exprs.push(Expr::Column(out_col));
}
- LogicalPlan::Aggregate(Aggregate {
- group_expr,
- aggr_expr,
- input,
- schema,
- ..
- }) => {
- // rewrite inputs
- let input_schema = Arc::clone(input.schema());
- let group_arrays = to_arrays(
- group_expr,
- Arc::clone(&input_schema),
- &mut expr_set,
- ExprMask::Normal,
- )?;
- let aggr_arrays =
- to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
-
- let (mut new_expr, new_input) = self.rewrite_expr(
- &[group_expr, aggr_expr],
- &[&group_arrays, &aggr_arrays],
- input,
- &mut expr_set,
- config,
- )?;
- // note the reversed pop order.
- let new_aggr_expr = pop_expr(&mut new_expr)?;
- let new_group_expr = pop_expr(&mut new_expr)?;
-
- // create potential projection on top
- let mut expr_set = ExprSet::new();
- let new_input_schema = Arc::clone(new_input.schema());
- let aggr_arrays = to_arrays(
- &new_aggr_expr,
- new_input_schema.clone(),
- &mut expr_set,
- ExprMask::NormalAndAggregates,
- )?;
- let mut affected_id = BTreeSet::<Identifier>::new();
- let mut rewritten = self.rewrite_exprs_list(
- &[&new_aggr_expr],
- &[&aggr_arrays],
- &mut expr_set,
- &mut affected_id,
- )?;
- let rewritten = pop_expr(&mut rewritten)?;
-
- if affected_id.is_empty() {
- Some(LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
- Arc::new(new_input),
- new_group_expr,
- new_aggr_expr,
- schema.clone(),
- )?))
- } else {
- let mut agg_exprs = vec![];
-
- for id in affected_id {
- match expr_set.get(&id) {
- Some((expr, _, _)) => {
- // todo: check `nullable`
- agg_exprs.push(expr.clone().alias(&id));
- }
- _ => {
- return Err(DataFusionError::Internal(
- "expr_set invalid state".to_string(),
- ));
- }
- }
+ for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) {
+ if expr_rewritten == expr_orig {
+ if let Expr::Alias(expr, name) = expr_rewritten {
+ agg_exprs.push(expr.alias(&name));
+ proj_exprs.push(Expr::Column(Column::from_name(name)));
+ } else {
+ let id =
+ ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten);
+ let out_name =
+ expr_rewritten.to_field(&new_input_schema)?.qualified_name();
+ agg_exprs.push(expr_rewritten.alias(&id));
+ proj_exprs
+ .push(Expr::Column(Column::from_name(id)).alias(out_name));
}
+ } else {
+ proj_exprs.push(expr_rewritten);
+ }
+ }
- let mut proj_exprs = vec![];
- for expr in &new_group_expr {
- let out_col: Column =
- expr.to_field(&new_input_schema)?.qualified_column();
- proj_exprs.push(Expr::Column(out_col));
- }
- for (expr_rewritten, expr_orig) in
- rewritten.into_iter().zip(new_aggr_expr)
- {
- if expr_rewritten == expr_orig {
- if let Expr::Alias(expr, name) = expr_rewritten {
- agg_exprs.push(expr.alias(&name));
- proj_exprs.push(Expr::Column(Column::from_name(name)));
- } else {
- let id = ExprIdentifierVisitor::<'static>::desc_expr(
- &expr_rewritten,
- );
- let out_name = expr_rewritten
- .to_field(&new_input_schema)?
- .qualified_name();
- agg_exprs.push(expr_rewritten.alias(&id));
- proj_exprs.push(
- Expr::Column(Column::from_name(id)).alias(out_name),
- );
- }
- } else {
- proj_exprs.push(expr_rewritten);
- }
- }
+ let agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(new_input),
+ new_group_expr,
+ agg_exprs,
+ )?);
- let agg = LogicalPlan::Aggregate(Aggregate::try_new(
- Arc::new(new_input),
- new_group_expr,
- agg_exprs,
- )?);
+ Ok(LogicalPlan::Projection(Projection::try_new(
+ proj_exprs,
+ Arc::new(agg),
+ )?))
+ }
+ }
- Some(LogicalPlan::Projection(Projection::try_new(
- proj_exprs,
- Arc::new(agg),
- )?))
- }
+ fn try_optimize_sort(
+ &self,
+ sort: &Sort,
+ config: &dyn OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ let Sort { expr, input, fetch } = sort;
+ let mut expr_set = ExprSet::new();
+
+ let input_schema = Arc::clone(input.schema());
+ let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+
+ let (mut new_expr, new_input) =
+ self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?;
+
+ Ok(LogicalPlan::Sort(Sort {
+ expr: pop_expr(&mut new_expr)?,
+ input: Arc::new(new_input),
+ fetch: *fetch,
+ }))
+ }
+}
+
+impl OptimizerRule for CommonSubexprEliminate {
+ fn try_optimize(
+ &self,
+ plan: &LogicalPlan,
+ config: &dyn OptimizerConfig,
+ ) -> Result<Option<LogicalPlan>> {
+ let optimized_plan = match plan {
+ LogicalPlan::Projection(projection) => {
+ Some(self.try_optimize_projection(projection, config)?)
}
- LogicalPlan::Sort(Sort { expr, input, fetch }) => {
- let input_schema = Arc::clone(input.schema());
- let arrays =
- to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
-
- let (mut new_expr, new_input) =
- self.rewrite_expr(&[expr], &[&arrays], input, &mut expr_set, config)?;
-
- Some(LogicalPlan::Sort(Sort {
- expr: pop_expr(&mut new_expr)?,
- input: Arc::new(new_input),
- fetch: *fetch,
- }))
+ LogicalPlan::Filter(filter) => {
+ Some(self.try_optimize_filter(filter, config)?)
+ }
+ LogicalPlan::Window(window) => {
+ Some(self.try_optimize_window(window, config)?)
+ }
+ LogicalPlan::Aggregate(aggregate) => {
+ Some(self.try_optimize_aggregate(aggregate, config)?)
}
+ LogicalPlan::Sort(sort) => Some(self.try_optimize_sort(sort, config)?),
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
@@ -345,6 +376,7 @@ impl OptimizerRule for CommonSubexprEliminate {
}
};
+ let original_schema = plan.schema().clone();
match optimized_plan {
Some(optimized_plan) if optimized_plan.schema() != &original_schema => {
// add an additional projection if the output schema changed.
@@ -634,7 +666,7 @@ fn expr_to_identifier(
/// the corresponding temporary column name. That column contains the
/// evaluate result of replaced expression.
struct CommonSubexprRewriter<'a> {
- expr_set: &'a mut ExprSet,
+ expr_set: &'a ExprSet,
id_array: &'a [(usize, Identifier)],
/// Which identifier is replaced.
affected_id: &'a mut BTreeSet<Identifier>,
@@ -717,7 +749,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
fn replace_common_expr(
expr: Expr,
id_array: &[(usize, Identifier)],
- expr_set: &mut ExprSet,
+ expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
expr.rewrite(&mut CommonSubexprRewriter {
@@ -783,14 +815,14 @@ mod test {
)?;
let expected = vec![
- (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
- (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
- (4, ""),
- (3, "a + Int32(1)Int32(1)a"),
- (1, ""),
- (2, ""),
- (6, ""),
- (5, ""),
+ (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
+ (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
+ (4, ""),
+ (3, "a + Int32(1)Int32(1)a"),
+ (1, ""),
+ (2, ""),
+ (6, ""),
+ (5, ""),
(8, "")
]
.into_iter()