You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/08/13 14:20:28 UTC

[GitHub] [arrow] alamb commented on a change in pull request #7880: ARROW-9619: [Rust] [DataFusion] Add predicate push-down

alamb commented on a change in pull request #7880:
URL: https://github.com/apache/arrow/pull/7880#discussion_r469914297



##########
File path: rust/datafusion/src/optimizer/utils.rs
##########
@@ -183,6 +183,162 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
     }
 }
 
+/// returns all expressions in the logical plan.
+pub fn expressions(plan: &LogicalPlan) -> Vec<Expr> {
+    match plan {
+        LogicalPlan::Projection { expr, .. } => expr.clone(),
+        LogicalPlan::Selection { expr, .. } => vec![expr.clone()],
+        LogicalPlan::Aggregate {
+            group_expr,
+            aggr_expr,
+            ..
+        } => {
+            let mut result = group_expr.clone();
+            result.extend(aggr_expr.clone());
+            result
+        }
+        LogicalPlan::Sort { expr, .. } => expr.clone(),
+        // plans without expressions
+        LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::Limit { .. }
+        | LogicalPlan::CreateExternalTable { .. } => vec![],
+    }
+}
+
+/// returns all inputs in the logical plan
+pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> {
+    match plan {
+        LogicalPlan::Projection { input, .. } => vec![input],
+        LogicalPlan::Selection { input, .. } => vec![input],
+        LogicalPlan::Aggregate { input, .. } => vec![input],
+        LogicalPlan::Sort { input, .. } => vec![input],
+        LogicalPlan::Limit { input, .. } => vec![input],
+        // plans without inputs
+        LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::CreateExternalTable { .. } => vec![],
+    }
+}
+
+/// Returns a new logical plan based on the original one with inputs and expressions replaced
+pub fn from_plan(
+    plan: &LogicalPlan,
+    expr: &Vec<Expr>,
+    inputs: &Vec<LogicalPlan>,
+) -> Result<LogicalPlan> {
+    match plan {
+        LogicalPlan::Projection { schema, .. } => Ok(LogicalPlan::Projection {
+            expr: expr.clone(),
+            input: Box::new(inputs[0].clone()),
+            schema: schema.clone(),
+        }),
+        LogicalPlan::Selection { .. } => Ok(LogicalPlan::Selection {
+            expr: expr[0].clone(),
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::Aggregate {
+            group_expr, schema, ..
+        } => Ok(LogicalPlan::Aggregate {
+            group_expr: expr[0..group_expr.len()].to_vec(),
+            aggr_expr: expr[group_expr.len()..].to_vec(),
+            input: Box::new(inputs[0].clone()),
+            schema: schema.clone(),
+        }),
+        LogicalPlan::Sort { .. } => Ok(LogicalPlan::Sort {
+            expr: expr.clone(),
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit {
+            n: *n,
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
+    }
+}
+
+/// Returns all expressions composing the expression.
+/// E.g. if the expression is "(a + 1) + 1", it returns ["a + 1", "1"] (as Expr objects)
+pub fn expr_expressions(expr: &Expr) -> Result<Vec<&Expr>> {
+    match expr {
+        Expr::BinaryExpr { left, right, .. } => Ok(vec![left, right]),
+        Expr::IsNull(e) => Ok(vec![e]),
+        Expr::IsNotNull(e) => Ok(vec![e]),
+        Expr::ScalarFunction { args, .. } => Ok(args.iter().collect()),
+        Expr::AggregateFunction { args, .. } => Ok(args.iter().collect()),
+        Expr::Cast { expr, .. } => Ok(vec![expr]),
+        Expr::Column(_) => Ok(vec![]),
+        Expr::Alias(expr, ..) => Ok(vec![expr]),
+        Expr::Literal(_) => Ok(vec![]),
+        Expr::Not(expr) => Ok(vec![expr]),
+        Expr::Sort { expr, .. } => Ok(vec![expr]),
+        Expr::Wildcard { .. } => Err(ExecutionError::General(
+            "Wildcard expressions are not valid in a logical query plan".to_owned(),
+        )),
+        Expr::Nested(expr) => Ok(vec![expr]),
+    }
+}
+
+/// returns a new expression where the sub-expressions of `expr` are replaced by the ones in `expressions`.
+/// This is used in conjunction with ``expr_expressions`` to re-write expressions.
+pub fn from_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr> {

Review comment:
       Calling this `rewrite_expression` might make it clearer what this function is doing.

##########
File path: rust/datafusion/src/optimizer/utils.rs
##########
@@ -183,6 +183,162 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
     }
 }
 
+/// returns all expressions in the logical plan.
+pub fn expressions(plan: &LogicalPlan) -> Vec<Expr> {
+    match plan {
+        LogicalPlan::Projection { expr, .. } => expr.clone(),
+        LogicalPlan::Selection { expr, .. } => vec![expr.clone()],
+        LogicalPlan::Aggregate {
+            group_expr,
+            aggr_expr,
+            ..
+        } => {
+            let mut result = group_expr.clone();
+            result.extend(aggr_expr.clone());
+            result
+        }
+        LogicalPlan::Sort { expr, .. } => expr.clone(),
+        // plans without expressions
+        LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::Limit { .. }
+        | LogicalPlan::CreateExternalTable { .. } => vec![],
+    }
+}
+
+/// returns all inputs in the logical plan
+pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> {
+    match plan {
+        LogicalPlan::Projection { input, .. } => vec![input],
+        LogicalPlan::Selection { input, .. } => vec![input],
+        LogicalPlan::Aggregate { input, .. } => vec![input],
+        LogicalPlan::Sort { input, .. } => vec![input],
+        LogicalPlan::Limit { input, .. } => vec![input],
+        // plans without inputs
+        LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::CreateExternalTable { .. } => vec![],
+    }
+}
+
+/// Returns a new logical plan based on the original one with inputs and expressions replaced
+pub fn from_plan(
+    plan: &LogicalPlan,
+    expr: &Vec<Expr>,
+    inputs: &Vec<LogicalPlan>,
+) -> Result<LogicalPlan> {
+    match plan {
+        LogicalPlan::Projection { schema, .. } => Ok(LogicalPlan::Projection {
+            expr: expr.clone(),
+            input: Box::new(inputs[0].clone()),
+            schema: schema.clone(),
+        }),
+        LogicalPlan::Selection { .. } => Ok(LogicalPlan::Selection {
+            expr: expr[0].clone(),
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::Aggregate {
+            group_expr, schema, ..
+        } => Ok(LogicalPlan::Aggregate {
+            group_expr: expr[0..group_expr.len()].to_vec(),
+            aggr_expr: expr[group_expr.len()..].to_vec(),
+            input: Box::new(inputs[0].clone()),
+            schema: schema.clone(),
+        }),
+        LogicalPlan::Sort { .. } => Ok(LogicalPlan::Sort {
+            expr: expr.clone(),
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit {
+            n: *n,
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
+    }
+}
+
+/// Returns all expressions composing the expression.
+/// E.g. if the expression is "(a + 1) + 1", it returns ["a + 1", "1"] (as Expr objects)
+pub fn expr_expressions(expr: &Expr) -> Result<Vec<&Expr>> {

Review comment:
       ```suggestion
   pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<&Expr>> {
   ```

##########
File path: rust/datafusion/src/optimizer/utils.rs
##########
@@ -183,6 +183,162 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
     }
 }
 
+/// returns all expressions in the logical plan.
+pub fn expressions(plan: &LogicalPlan) -> Vec<Expr> {
+    match plan {
+        LogicalPlan::Projection { expr, .. } => expr.clone(),
+        LogicalPlan::Selection { expr, .. } => vec![expr.clone()],
+        LogicalPlan::Aggregate {
+            group_expr,
+            aggr_expr,
+            ..
+        } => {
+            let mut result = group_expr.clone();
+            result.extend(aggr_expr.clone());
+            result
+        }
+        LogicalPlan::Sort { expr, .. } => expr.clone(),
+        // plans without expressions
+        LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::Limit { .. }
+        | LogicalPlan::CreateExternalTable { .. } => vec![],
+    }
+}
+
+/// returns all inputs in the logical plan
+pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> {
+    match plan {
+        LogicalPlan::Projection { input, .. } => vec![input],
+        LogicalPlan::Selection { input, .. } => vec![input],
+        LogicalPlan::Aggregate { input, .. } => vec![input],
+        LogicalPlan::Sort { input, .. } => vec![input],
+        LogicalPlan::Limit { input, .. } => vec![input],
+        // plans without inputs
+        LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::CreateExternalTable { .. } => vec![],
+    }
+}
+
+/// Returns a new logical plan based on the original one with inputs and expressions replaced
+pub fn from_plan(
+    plan: &LogicalPlan,
+    expr: &Vec<Expr>,
+    inputs: &Vec<LogicalPlan>,
+) -> Result<LogicalPlan> {
+    match plan {
+        LogicalPlan::Projection { schema, .. } => Ok(LogicalPlan::Projection {
+            expr: expr.clone(),
+            input: Box::new(inputs[0].clone()),
+            schema: schema.clone(),
+        }),
+        LogicalPlan::Selection { .. } => Ok(LogicalPlan::Selection {
+            expr: expr[0].clone(),
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::Aggregate {
+            group_expr, schema, ..
+        } => Ok(LogicalPlan::Aggregate {
+            group_expr: expr[0..group_expr.len()].to_vec(),
+            aggr_expr: expr[group_expr.len()..].to_vec(),
+            input: Box::new(inputs[0].clone()),
+            schema: schema.clone(),
+        }),
+        LogicalPlan::Sort { .. } => Ok(LogicalPlan::Sort {
+            expr: expr.clone(),
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit {
+            n: *n,
+            input: Box::new(inputs[0].clone()),
+        }),
+        LogicalPlan::EmptyRelation { .. }
+        | LogicalPlan::TableScan { .. }
+        | LogicalPlan::InMemoryScan { .. }
+        | LogicalPlan::ParquetScan { .. }
+        | LogicalPlan::CsvScan { .. }
+        | LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
+    }
+}
+
+/// Returns all expressions composing the expression.

Review comment:
       ```suggestion
   /// Returns direct children `Expression`s of `expr`.
   ```

##########
File path: rust/datafusion/src/optimizer/utils.rs
##########
@@ -183,6 +183,162 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
     }
 }
 
+/// returns all expressions in the logical plan.

Review comment:
       ```suggestion
   /// returns all expressions (non-recursively) in the current logical plan node.
   ```

##########
File path: rust/datafusion/src/optimizer/filter_push_down.rs
##########
@@ -0,0 +1,467 @@
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
+
+use crate::error::Result;
+use crate::logicalplan::Expr;
+use crate::logicalplan::LogicalPlan;
+use crate::optimizer::optimizer::OptimizerRule;
+use crate::optimizer::utils;
+use std::collections::{HashMap, HashSet};
+
+/// Filter Push Down optimizer rule pushes filter clauses down the plan
+///
+/// This optimization looks for the maximum depth of each column in the plan where a filter can be applied and
+/// re-writes the plan with filters on those locations.
+/// It performs two passes on the plan:
+/// 1. identify filters, which columns they use, and projections along the path
+/// 2. move filters down, re-writing the expressions using the projections
+pub struct FilterPushDown {}
+
+impl OptimizerRule for FilterPushDown {
+    fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
+        let (break_points, selections, projections) = analyze_plan(plan, 0)?;
+
+        // compute max depth for each of the columns
+        let mut breaks: HashMap<String, usize> = HashMap::new();
+        for (key, depth) in break_points {
+            match breaks.get(&key) {
+                Some(current_depth) => {
+                    if depth > *current_depth {
+                        breaks.insert(key, depth);
+                    }
+                }
+                None => {
+                    breaks.insert(key, depth);
+                }
+            }
+        }
+
+        // construct optimized position of each of the new selections
+        let mut new_selections: HashMap<usize, Expr> = HashMap::new();
+        for (selection_depth, expr) in selections {
+            let mut columns: HashSet<String> = HashSet::new();
+            utils::expr_to_column_names(&expr, &mut columns)?;
+
+            // compute the depths of each of the observed columns and the respective maximum
+            let depth = columns
+                .iter()
+                .filter_map(|column| breaks.get(column))
+                .max_by_key(|depth| **depth);
+
+            let new_depth = match depth {
+                None => selection_depth,
+                Some(d) => *d,
+            };
+
+            // re-write the new selections based on all projections that it crossed.
+            // E.g. in `Selection: #b\n  Projection: #a > 1 as b`, we can swap them, but the selection must be "#a > 1"
+            let mut new_expression = expr.clone();
+            for depth_i in selection_depth..new_depth {
+                if let Some(projection) = projections.get(&depth_i) {
+                    new_expression = rewrite(&new_expression, projection)?;
+                }
+            }
+
+            new_selections.insert(new_depth, new_expression);
+        }
+
+        optimize_plan(plan, &new_selections, 0)
+    }
+}
+
+/// Recursively traverses the logical plan looking for depths that break filter pushdown
+/// Returns a tuple:
+/// 0: map "column -> depth" of the depth that each column is found up to.
+/// 1: map "depth -> filter expression"
+/// 2: map "depth -> projection"
+fn analyze_plan(
+    plan: &LogicalPlan,
+    depth: usize,
+) -> Result<(
+    HashMap<String, usize>,
+    HashMap<usize, Expr>,
+    HashMap<usize, HashMap<String, Expr>>,
+)> {
+    match plan {
+        LogicalPlan::Selection { input, expr } => {
+            let mut result = analyze_plan(&input, depth + 1)?;
+            result.1.insert(depth, expr.clone());
+            Ok(result)
+        }
+        LogicalPlan::Projection {
+            input,
+            expr,
+            schema,
+        } => {
+            let mut result = analyze_plan(&input, depth + 1)?;
+
+            // collect projection.
+            let mut projection = HashMap::new();
+            schema.fields().iter().enumerate().for_each(|(i, field)| {
+                // strip alias, as they should not be part of selections
+                let expr = match &expr[i] {
+                    Expr::Alias(expr, _) => expr.as_ref().clone(),
+                    expr => expr.clone(),
+                };
+
+                projection.insert(field.name().clone(), expr);
+            });
+            result.2.insert(depth, projection);
+            Ok(result)
+        }
+        LogicalPlan::Aggregate {
+            input, aggr_expr, ..
+        } => {
+            let mut result = analyze_plan(&input, depth + 1)?;
+
+            let mut accum = HashSet::new();
+            utils::exprlist_to_column_names(aggr_expr, &mut accum)?;
+
+            accum.iter().for_each(|x: &String| {
+                result.0.insert(x.clone(), depth);
+            });
+
+            Ok(result)
+        }
+        LogicalPlan::Sort { input, .. } => analyze_plan(&input, depth + 1),
+        LogicalPlan::Limit { input, .. } => {
+            let mut result = analyze_plan(&input, depth + 1)?;
+
+            // pick all fields in the input schema of this limit: none of them can be used after this limit.
+            input.schema().fields().iter().for_each(|x| {
+                result.0.insert(x.name().clone(), depth);
+            });
+            Ok(result)
+        }
+        // all other plans add breaks to all their columns to indicate that filters can't proceed further.
+        _ => {
+            let mut result = HashMap::new();
+            plan.schema().fields().iter().for_each(|x| {
+                result.insert(x.name().clone(), depth);
+            });
+            Ok((result, HashMap::new(), HashMap::new()))
+        }
+    }
+}
+
+impl FilterPushDown {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+/// Returns a re-written logical plan where all old filters are removed and the new ones are added.
+fn optimize_plan(
+    plan: &LogicalPlan,
+    new_selections: &HashMap<usize, Expr>,
+    depth: usize,
+) -> Result<LogicalPlan> {
+    // optimize the plan recursively:
+    let new_plan = match plan {
+        LogicalPlan::Selection { input, .. } => {
+            // ignore old selections
+            Ok(optimize_plan(&input, new_selections, depth + 1)?)
+        }
+        _ => {
+            // all other nodes are copied, optimizing recursively.
+            let expr = utils::expressions(plan);
+
+            let inputs = utils::inputs(plan);
+            let new_inputs = inputs
+                .iter()
+                .map(|plan| optimize_plan(plan, new_selections, depth + 1))
+                .collect::<Result<Vec<_>>>()?;
+
+            utils::from_plan(plan, &expr, &new_inputs)
+        }
+    }?;
+
+    // if a new selection is to be applied, apply it
+    if let Some(expr) = new_selections.get(&depth) {
+        return Ok(LogicalPlan::Selection {
+            expr: expr.clone(),
+            input: Box::new(new_plan),
+        });
+    } else {
+        Ok(new_plan)
+    }
+}
+
+/// replaces each column by the expression registered under the same name in the projection.
+fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
+    let expressions = utils::expr_expressions(&expr)?;
+
+    let expressions = expressions
+        .iter()
+        .map(|e| rewrite(e, &projection))
+        .collect::<Result<Vec<_>>>()?;
+
+    match expr {
+        Expr::Column(name) => {
+            if let Some(expr) = projection.get(name) {
+                return Ok(expr.clone());
+            }
+        }
+        _ => {}
+    }
+
+    utils::from_expression(&expr, &expressions)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::logicalplan::col;
+    use crate::logicalplan::ScalarValue;
+    use crate::logicalplan::{aggregate_expr, lit, Expr, LogicalPlanBuilder, Operator};
+    use crate::test::*;
+    use arrow::datatypes::DataType;
+
+    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
+        let mut rule = FilterPushDown::new();
+        let optimized_plan = rule.optimize(plan).expect("failed to optimize plan");
+        let formatted_plan = format!("{:?}", optimized_plan);
+        assert_eq!(formatted_plan, expected);
+    }
+
+    #[test]
+    fn filter_before_projection() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("b")])?
+            .filter(col("a").eq(&Expr::Literal(ScalarValue::Int64(1))))?
+            .build()?;
+        // selection is before projection
+        let expected = "\
+            Projection: #a, #b\
+            \n  Selection: #a Eq Int64(1)\
+            \n    TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn filter_after_limit() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("b")])?
+            .limit(10)?
+            .filter(col("a").eq(&Expr::Literal(ScalarValue::Int64(1))))?
+            .build()?;
+        // selection is not pushed below the limit: it stays above it
+        let expected = "\
+            Selection: #a Eq Int64(1)\
+            \n  Limit: 10\
+            \n    Projection: #a, #b\
+            \n      TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn filter_jump_2_plans() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .project(vec![col("c"), col("b")])?
+            .filter(col("a").eq(&Expr::Literal(ScalarValue::Int64(1))))?
+            .build()?;
+        // selection is before double projection
+        let expected = "\
+            Projection: #c, #b\
+            \n  Projection: #a, #b, #c\
+            \n    Selection: #a Eq Int64(1)\
+            \n      TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn filter_move_agg() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .aggregate(
+                vec![col("a")],
+                vec![aggregate_expr("SUM", col("b"), DataType::Int32)
+                    .alias("total_salary")],
+            )?
+            .filter(col("a").gt(&Expr::Literal(ScalarValue::Int64(10))))?
+            .build()?;
+        // selection of key aggregation is commutative
+        let expected = "\
+            Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS total_salary]]\
+            \n  Selection: #a Gt Int64(10)\
+            \n    TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn filter_keep_agg() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .aggregate(
+                vec![col("a")],
+                vec![aggregate_expr("SUM", col("b"), DataType::Int32).alias("b")],
+            )?
+            .filter(col("b").gt(&Expr::Literal(ScalarValue::Int64(10))))?
+            .build()?;
+        // selection of aggregate is after aggregation since they are non-commutative
+        let expected = "\
+            Selection: #b Gt Int64(10)\
+            \n  Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS b]]\
+            \n    TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// verifies that when a filter is pushed to before a projection, the filter expression is correctly re-written
+    #[test]
+    fn alias() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![col("a").alias("b"), col("c")])?
+            .filter(col("b").eq(&Expr::Literal(ScalarValue::Int64(1))))?
+            .build()?;
+        // selection is before projection
+        let expected = "\
+            Projection: #a AS b, #c\
+            \n  Selection: #a Eq Int64(1)\
+            \n    TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    fn add(left: Expr, right: Expr) -> Expr {
+        Expr::BinaryExpr {
+            left: Box::new(left),
+            op: Operator::Plus,
+            right: Box::new(right),
+        }
+    }
+
+    fn multiply(left: Expr, right: Expr) -> Expr {
+        Expr::BinaryExpr {
+            left: Box::new(left),
+            op: Operator::Multiply,
+            right: Box::new(right),
+        }
+    }
+
+    /// verifies that when a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
+    #[test]
+    fn complex_expression() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![
+                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
+                col("c"),
+            ])?
+            .filter(col("b").eq(&Expr::Literal(ScalarValue::Int64(1))))?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Selection: #b Eq Int64(1)\
+            \n  Projection: #a Multiply Int32(2) Plus #c AS b, #c\
+            \n    TableScan: test projection=None"
+        );
+
+        // selection is before projection
+        let expected = "\
+            Projection: #a Multiply Int32(2) Plus #c AS b, #c\
+            \n  Selection: #a Multiply Int32(2) Plus #c Eq Int64(1)\
+            \n    TableScan: test projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
+    #[test]
+    fn complex_plan() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(&table_scan)
+            .project(vec![
+                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
+                col("c"),
+            ])?
+            // second projection where we rename columns, just to make it difficult
+            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
+            .filter(col("a").eq(&Expr::Literal(ScalarValue::Int64(1))))?
+            .build()?;
+
+        // not part of the test, just good to know:

Review comment:
       👍 

##########
File path: rust/datafusion/src/optimizer/type_coercion.rs
##########
@@ -43,138 +45,77 @@ impl<'a> TypeCoercionRule<'a> {
         Self { scalar_functions }
     }
 
-    /// Rewrite an expression list to include explicit CAST operations when required
-    fn rewrite_expr_list(&self, expr: &[Expr], schema: &Schema) -> Result<Vec<Expr>> {
-        Ok(expr
+    /// Rewrite an expression to include explicit CAST operations when required
+    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+        let expressions = utils::expr_expressions(expr)?;
+
+        // recursively re-write each sub-expression
+        let mut expressions = expressions
             .iter()
             .map(|e| self.rewrite_expr(e, schema))
-            .collect::<Result<Vec<_>>>()?)
-    }
+            .collect::<Result<Vec<_>>>()?;
 
-    /// Rewrite an expression to include explicit CAST operations when required
-    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
+        // modify `expressions` by introducing casts when necessary
         match expr {
-            Expr::BinaryExpr { left, op, right } => {
-                let left = self.rewrite_expr(left, schema)?;
-                let right = self.rewrite_expr(right, schema)?;
-                let left_type = left.get_type(schema)?;
-                let right_type = right.get_type(schema)?;
-                if left_type == right_type {
-                    Ok(Expr::BinaryExpr {
-                        left: Box::new(left),
-                        op: op.clone(),
-                        right: Box::new(right),
-                    })
-                } else {
+            Expr::BinaryExpr { .. } => {
+                let left_type = expressions[0].get_type(schema)?;
+                let right_type = expressions[1].get_type(schema)?;
+                if left_type != right_type {
                     let super_type = utils::get_supertype(&left_type, &right_type)?;
-                    Ok(Expr::BinaryExpr {
-                        left: Box::new(left.cast_to(&super_type, schema)?),
-                        op: op.clone(),
-                        right: Box::new(right.cast_to(&super_type, schema)?),
-                    })
+
+                    expressions[0] = expressions[0].cast_to(&super_type, schema)?;
+                    expressions[1] = expressions[1].cast_to(&super_type, schema)?;
                 }
             }
-            Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e, schema)?))),
-            Expr::IsNotNull(e) => {
-                Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, schema)?)))
-            }
-            Expr::ScalarFunction {
-                name,
-                args,
-                return_type,
-            } => {
+            Expr::ScalarFunction { name, .. } => {
                 // cast the inputs of scalar functions to the appropriate type where possible
                 match self.scalar_functions.get(name) {
                     Some(func_meta) => {
-                        let mut func_args = Vec::with_capacity(args.len());
-                        for i in 0..args.len() {
+                        for i in 0..expressions.len() {
                             let field = &func_meta.args[i];
-                            let expr = self.rewrite_expr(&args[i], schema)?;
-                            let actual_type = expr.get_type(schema)?;
+                            let actual_type = expressions[i].get_type(schema)?;
                             let required_type = field.data_type();
-                            if &actual_type == required_type {
-                                func_args.push(expr)
-                            } else {
+                            if &actual_type != required_type {
                                 let super_type =
                                     utils::get_supertype(&actual_type, required_type)?;
-                                func_args.push(expr.cast_to(&super_type, schema)?);
-                            }
+                                expressions[i] =
+                                    expressions[i].cast_to(&super_type, schema)?
+                            };
                         }
-
-                        Ok(Expr::ScalarFunction {
-                            name: name.clone(),
-                            args: func_args,
-                            return_type: return_type.clone(),
-                        })
                     }
-                    _ => Err(ExecutionError::General(format!(
-                        "Invalid scalar function {}",
-                        name
-                    ))),
+                    _ => {
+                        return Err(ExecutionError::General(format!(
+                            "Invalid scalar function {}",
+                            name
+                        )))
+                    }
                 }
             }
-            Expr::AggregateFunction {
-                name,
-                args,
-                return_type,
-            } => Ok(Expr::AggregateFunction {
-                name: name.clone(),
-                args: args
-                    .iter()
-                    .map(|a| self.rewrite_expr(a, schema))
-                    .collect::<Result<Vec<_>>>()?,
-                return_type: return_type.clone(),
-            }),
-            Expr::Cast { .. } => Ok(expr.clone()),
-            Expr::Column(_) => Ok(expr.clone()),
-            Expr::Alias(expr, alias) => Ok(Expr::Alias(
-                Box::new(self.rewrite_expr(expr, schema)?),
-                alias.to_owned(),
-            )),
-            Expr::Literal(_) => Ok(expr.clone()),
-            Expr::Not(_) => Ok(expr.clone()),
-            Expr::Sort { .. } => Ok(expr.clone()),
-            Expr::Wildcard { .. } => Err(ExecutionError::General(
-                "Wildcard expressions are not valid in a logical query plan".to_owned(),
-            )),
-            Expr::Nested(e) => self.rewrite_expr(e, schema),
-        }
+            _ => {}
+        };
+        utils::from_expression(expr, &expressions)
     }
 }
 
 impl<'a> OptimizerRule for TypeCoercionRule<'a> {
     fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
-        match plan {
-            LogicalPlan::Projection { expr, input, .. } => {
-                LogicalPlanBuilder::from(&self.optimize(input)?)
-                    .project(self.rewrite_expr_list(expr, input.schema())?)?
-                    .build()
-            }
-            LogicalPlan::Selection { expr, input, .. } => {
-                LogicalPlanBuilder::from(&self.optimize(input)?)
-                    .filter(self.rewrite_expr(expr, input.schema())?)?
-                    .build()
-            }
-            LogicalPlan::Aggregate {
-                input,
-                group_expr,
-                aggr_expr,
-                ..
-            } => LogicalPlanBuilder::from(&self.optimize(input)?)
-                .aggregate(
-                    self.rewrite_expr_list(group_expr, input.schema())?,
-                    self.rewrite_expr_list(aggr_expr, input.schema())?,
-                )?
-                .build(),
-            LogicalPlan::TableScan { .. } => Ok(plan.clone()),
-            LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()),
-            LogicalPlan::ParquetScan { .. } => Ok(plan.clone()),
-            LogicalPlan::CsvScan { .. } => Ok(plan.clone()),
-            LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
-            LogicalPlan::Limit { .. } => Ok(plan.clone()),
-            LogicalPlan::Sort { .. } => Ok(plan.clone()),
-            LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
-        }
+        let inputs = utils::inputs(plan);
+        let expressions = utils::expressions(plan);
+
+        // apply the optimization to all inputs of the plan
+        let new_inputs = inputs
+            .iter()
+            .map(|plan| self.optimize(*plan))
+            .collect::<Result<Vec<_>>>()?;
+        // re-write all expressions on this plan.
+        // This assumes a single input, [0]. It won't work for joins, subqueries and union operations with more than one input.
+        // It is currently not an issue as we do not have any plan with more than one input.
+        let new_expressions = expressions

Review comment:
       Do you have to check for the case here where the `LogicalPlan` node has no inputs?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org