You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/30 19:51:08 UTC

[arrow-datafusion] branch master updated: reimplement `push_down_filter` to remove global-state (#4365)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fe542f2a reimplement `push_down_filter` to remove global-state (#4365)
3fe542f2a is described below

commit 3fe542f2afcd5360edc9abb7ad1e8243b560a6b2
Author: jakevin <ja...@gmail.com>
AuthorDate: Thu Dec 1 03:51:03 2022 +0800

    reimplement `push_down_filter` to remove global-state (#4365)
    
    * reimplement `push_down_filter`
    
    * add comment
    
    * fix union replace
    
    * fix bug when meeting the same name but a different qualifier, and add UT
    
    * add UT `filter_complex_agg` `test_union_different_schema`
    
    * fix regression when push_down_filter meets a subquery alias
    
    * polish comment
    
    * add UT to confirm that duplicate Filters are avoided
    
    * merge confirm UT
    
    * remove TODO
---
 benchmarks/expected-plans/q21.txt                  |    8 +-
 benchmarks/expected-plans/q7.txt                   |    8 +-
 datafusion/core/tests/sql/joins.rs                 |   13 +-
 datafusion/optimizer/src/lib.rs                    |    2 +-
 datafusion/optimizer/src/optimizer.rs              |    5 +-
 .../{filter_push_down.rs => push_down_filter.rs}   | 1196 ++++++++++----------
 datafusion/optimizer/tests/integration-test.rs     |   31 +-
 7 files changed, 604 insertions(+), 659 deletions(-)

diff --git a/benchmarks/expected-plans/q21.txt b/benchmarks/expected-plans/q21.txt
index 397e0a8d8..3ef6269de 100644
--- a/benchmarks/expected-plans/q21.txt
+++ b/benchmarks/expected-plans/q21.txt
@@ -7,8 +7,8 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
             Inner Join: l1.l_orderkey = orders.o_orderkey
               Inner Join: supplier.s_suppkey = l1.l_suppkey
                 TableScan: supplier projection=[s_suppkey, s_name, s_nationkey]
-                Filter: l1.l_receiptdate > l1.l_commitdate
-                  SubqueryAlias: l1
+                SubqueryAlias: l1
+                  Filter: lineitem.l_receiptdate > lineitem.l_commitdate
                     TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate]
               Filter: orders.o_orderstatus = Utf8("F")
                 TableScan: orders projection=[o_orderkey, o_orderstatus]
@@ -16,6 +16,6 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
               TableScan: nation projection=[n_nationkey, n_name]
           SubqueryAlias: l2
             TableScan: lineitem projection=[l_orderkey, l_suppkey]
-        Filter: l3.l_receiptdate > l3.l_commitdate
-          SubqueryAlias: l3
+        SubqueryAlias: l3
+          Filter: lineitem.l_receiptdate > lineitem.l_commitdate
             TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate]
\ No newline at end of file
diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt
index 74857c6f9..53deda1b8 100644
--- a/benchmarks/expected-plans/q7.txt
+++ b/benchmarks/expected-plans/q7.txt
@@ -14,9 +14,9 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST,
                         TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate]
                     TableScan: orders projection=[o_orderkey, o_custkey]
                   TableScan: customer projection=[c_custkey, c_nationkey]
-                Filter: n1.n_name = Utf8("FRANCE") OR n1.n_name = Utf8("GERMANY")
-                  SubqueryAlias: n1
+                SubqueryAlias: n1
+                  Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY")
                     TableScan: nation projection=[n_nationkey, n_name]
-              Filter: n2.n_name = Utf8("GERMANY") OR n2.n_name = Utf8("FRANCE")
-                SubqueryAlias: n2
+              SubqueryAlias: n2
+                Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE")
                   TableScan: nation projection=[n_nationkey, n_name]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 87fb594c7..7129fc7ed 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1636,15 +1636,14 @@ async fn reduce_left_join_3() -> Result<()> {
             "Explain [plan_type:Utf8, plan:Utf8]",
             "  Projection: t3.t1_id, t3.t1_name, t3.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
             "    Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-            "      Filter: t3.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-            "        SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-            "          Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "      SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "        Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "          Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
             "            TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-            "            Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-            "              TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "          Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "            TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
             "      TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-        ]
-            ;
+        ];
         let formatted = plan.display_indent_schema().to_string();
         let actual: Vec<&str> = formatted.trim().lines().collect();
         assert_eq!(
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index cdfe7fc9b..aba53a3d8 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -23,12 +23,12 @@ pub mod eliminate_filter;
 pub mod eliminate_limit;
 pub mod eliminate_outer_join;
 pub mod filter_null_join_keys;
-pub mod filter_push_down;
 pub mod inline_table_scan;
 pub mod limit_push_down;
 pub mod optimizer;
 pub mod projection_push_down;
 pub mod propagate_empty_relation;
+pub mod push_down_filter;
 pub mod scalar_subquery_to_join;
 pub mod simplify_expressions;
 pub mod single_distinct_to_groupby;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 5340f4f80..19eeecb0f 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -25,11 +25,11 @@ use crate::eliminate_filter::EliminateFilter;
 use crate::eliminate_limit::EliminateLimit;
 use crate::eliminate_outer_join::EliminateOuterJoin;
 use crate::filter_null_join_keys::FilterNullJoinKeys;
-use crate::filter_push_down::FilterPushDown;
 use crate::inline_table_scan::InlineTableScan;
 use crate::limit_push_down::LimitPushDown;
 use crate::projection_push_down::ProjectionPushDown;
 use crate::propagate_empty_relation::PropagateEmptyRelation;
+use crate::push_down_filter::PushDownFilter;
 use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
 use crate::simplify_expressions::SimplifyExpressions;
@@ -184,8 +184,9 @@ impl Optimizer {
             rules.push(Arc::new(FilterNullJoinKeys::default()));
         }
         rules.push(Arc::new(EliminateOuterJoin::new()));
-        rules.push(Arc::new(FilterPushDown::new()));
+        // Filters can't be pushed down past Limits, so PushDownFilter should run after LimitPushDown
         rules.push(Arc::new(LimitPushDown::new()));
+        rules.push(Arc::new(PushDownFilter::new()));
         rules.push(Arc::new(SingleDistinctToGroupBy::new()));
 
         // The previous optimizations added expressions and projections,
diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/push_down_filter.rs
similarity index 72%
rename from datafusion/optimizer/src/filter_push_down.rs
rename to datafusion/optimizer/src/push_down_filter.rs
index 2f8a8a8b4..e59590df5 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -12,26 +12,25 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
+//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan
 
+use crate::utils::conjunction;
 use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, DataFusionError, Result};
+use datafusion_expr::utils::exprlist_to_columns;
 use datafusion_expr::{
-    and, col,
-    expr::BinaryExpr,
+    and,
     expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
-    logical_plan::{
-        Aggregate, CrossJoin, Join, JoinType, Limit, LogicalPlan, Projection, TableScan,
-        Union,
-    },
+    logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
     or,
-    utils::{expr_to_columns, exprlist_to_columns, from_plan},
-    Expr, Operator, TableProviderFilterPushDown,
+    utils::from_plan,
+    BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown,
 };
 use std::collections::{HashMap, HashSet};
 use std::iter::once;
+use std::sync::Arc;
 
-/// Filter Push Down optimizer rule pushes filter clauses down the plan
+/// Push Down Filter optimizer rule pushes filter clauses down the plan
 /// # Introduction
 /// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)).
 /// An example of a filter-commutative operation is a projection; a counter-example is `limit`.
@@ -57,96 +56,7 @@ use std::iter::once;
 /// When it passes through a projection, it re-writes the filter's expression taking into account that projection.
 /// When multiple filters would have been written, it `AND` their expressions into a single expression.
 #[derive(Default)]
-pub struct FilterPushDown {}
-
-/// Filter predicate represented by tuple of expression and its columns
-type Predicate = (Expr, HashSet<Column>);
-
-/// Multiple filter predicates represented by tuple of expressions vector
-/// and corresponding expression columns vector
-type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
-
-#[derive(Debug, Clone, Default)]
-struct State {
-    // (predicate, columns on the predicate)
-    filters: Vec<Predicate>,
-}
-
-impl State {
-    fn append_predicates(&mut self, predicates: Predicates) {
-        predicates
-            .0
-            .into_iter()
-            .zip(predicates.1)
-            .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone())))
-    }
-}
-
-/// returns all predicates in `state` that depend on any of `used_columns`
-/// or the ones that does not reference any columns (e.g. WHERE 1=1)
-fn get_predicates<'a>(
-    state: &'a State,
-    used_columns: &HashSet<Column>,
-) -> Predicates<'a> {
-    state
-        .filters
-        .iter()
-        .filter(|(_, columns)| {
-            columns.is_empty()
-                || !columns
-                    .intersection(used_columns)
-                    .collect::<HashSet<_>>()
-                    .is_empty()
-        })
-        .map(|&(ref a, ref b)| (a, b))
-        .unzip()
-}
-
-/// Optimizes the plan
-fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
-    let new_inputs = plan
-        .inputs()
-        .iter()
-        .map(|input| optimize(input, state.clone()))
-        .collect::<Result<Vec<_>>>()?;
-
-    let expr = plan.expressions();
-    from_plan(plan, &expr, &new_inputs)
-}
-
-// remove all filters from `filters` that are in `predicate_columns`
-fn remove_filters(
-    filters: &[Predicate],
-    predicate_columns: &[&HashSet<Column>],
-) -> Vec<Predicate> {
-    filters
-        .iter()
-        .filter(|(_, columns)| !predicate_columns.contains(&columns))
-        .cloned()
-        .collect::<Vec<_>>()
-}
-
-/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
-/// in `state` depend on the columns `used_columns`.
-fn issue_filters(
-    mut state: State,
-    used_columns: HashSet<Column>,
-    plan: &LogicalPlan,
-) -> Result<LogicalPlan> {
-    let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
-
-    if predicates.is_empty() {
-        // all filters can be pushed down => optimize inputs and return new plan
-        return push_down(&state, plan);
-    }
-
-    let plan = utils::add_filter(plan.clone(), &predicates)?;
-
-    state.filters = remove_filters(&state.filters, &predicate_columns);
-
-    // continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
-    push_down(&state, &plan)
-}
+pub struct PushDownFilter {}
 
 // For a given JOIN logical plan, determine whether each side of the join is preserved.
 // We say a join side is preserved if the join returns all or a subset of the rows from
@@ -220,15 +130,7 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
 // or not the side's rows are preserved when joining. If the side is not preserved, we
 // do not push down anything. Otherwise we can push down predicates where all of the
 // relevant columns are contained on the relevant join side's schema.
-fn get_pushable_join_predicates<'a>(
-    filters: &'a [Predicate],
-    schema: &DFSchema,
-    preserved: bool,
-) -> Predicates<'a> {
-    if !preserved {
-        return (vec![], vec![]);
-    }
-
+fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> {
     let schema_columns = schema
         .fields()
         .iter()
@@ -240,19 +142,13 @@ fn get_pushable_join_predicates<'a>(
             ]
         })
         .collect::<HashSet<_>>();
+    let columns = predicate.to_columns()?;
 
-    filters
-        .iter()
-        .filter(|(_, columns)| {
-            let all_columns_in_schema = schema_columns
-                .intersection(columns)
-                .collect::<HashSet<_>>()
-                .len()
-                == columns.len();
-            all_columns_in_schema
-        })
-        .map(|(a, b)| (a, b))
-        .unzip()
+    Ok(schema_columns
+        .intersection(&columns)
+        .collect::<HashSet<_>>()
+        .len()
+        == columns.len())
 }
 
 // examine OR clause to see if any useful clauses can be extracted and push down.
@@ -292,9 +188,9 @@ fn extract_or_clauses_for_join(
     filters: &[&Expr],
     schema: &DFSchema,
     preserved: bool,
-) -> (Vec<Expr>, Vec<HashSet<Column>>) {
+) -> Vec<Expr> {
     if !preserved {
-        return (vec![], vec![]);
+        return vec![];
     }
 
     let schema_columns = schema
@@ -310,7 +206,6 @@ fn extract_or_clauses_for_join(
         .collect::<HashSet<_>>();
 
     let mut exprs = vec![];
-    let mut expr_columns = vec![];
     for expr in filters.iter() {
         if let Expr::BinaryExpr(BinaryExpr {
             left,
@@ -323,17 +218,13 @@ fn extract_or_clauses_for_join(
 
             // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
             if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
-                let predicate = or(left_expr, right_expr);
-                let columns = predicate.to_columns().ok().unwrap();
-
-                exprs.push(predicate);
-                expr_columns.push(columns);
+                exprs.push(or(left_expr, right_expr));
             }
         }
     }
 
     // new formed OR clauses and their column references
-    (exprs, expr_columns)
+    exprs
 }
 
 // extract qual from OR sub-clause.
@@ -403,94 +294,89 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
     predicate
 }
 
-fn optimize_join(
-    mut state: State,
+// push down join/cross-join
+fn push_down_all_join(
+    predicates: Vec<Expr>,
     plan: &LogicalPlan,
     left: &LogicalPlan,
     right: &LogicalPlan,
-    on_filter: Vec<Predicate>,
+    on_filter: Vec<Expr>,
 ) -> Result<LogicalPlan> {
+    let on_filter_empty = on_filter.is_empty();
     // Get pushable predicates from current optimizer state
     let (left_preserved, right_preserved) = lr_is_preserved(plan)?;
-    let to_left =
-        get_pushable_join_predicates(&state.filters, left.schema(), left_preserved);
-    let to_right =
-        get_pushable_join_predicates(&state.filters, right.schema(), right_preserved);
-    let to_keep: Predicates = state
-        .filters
-        .iter()
-        .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e))
-        .map(|(a, b)| (a, b))
-        .unzip();
+    let mut left_push = vec![];
+    let mut right_push = vec![];
+
+    let mut keep_predicates = vec![];
+    for predicate in predicates {
+        if left_preserved && can_pushdown_join_predicate(&predicate, left.schema())? {
+            left_push.push(predicate);
+        } else if right_preserved
+            && can_pushdown_join_predicate(&predicate, right.schema())?
+        {
+            right_push.push(predicate);
+        } else {
+            keep_predicates.push(predicate);
+        }
+    }
 
-    // Get pushable predicates from join filter
-    let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() {
-        ((vec![], vec![]), (vec![], vec![]), vec![])
-    } else {
+    let mut keep_condition = vec![];
+    if !on_filter.is_empty() {
         let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?;
-        let on_to_left =
-            get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved);
-        let on_to_right =
-            get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved);
-        let on_to_keep = on_filter
-            .iter()
-            .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e))
-            .map(|(a, _)| a.clone())
-            .collect::<Vec<_>>();
-
-        (on_to_left, on_to_right, on_to_keep)
-    };
+        for on in on_filter {
+            if on_left_preserved && can_pushdown_join_predicate(&on, left.schema())? {
+                left_push.push(on)
+            } else if on_right_preserved
+                && can_pushdown_join_predicate(&on, right.schema())?
+            {
+                right_push.push(on)
+            } else {
+                keep_condition.push(on)
+            }
+        }
+    }
 
     // Extract from OR clause, generate new predicates for both side of join if possible.
     // We only track the unpushable predicates above.
-    let or_to_left =
-        extract_or_clauses_for_join(&to_keep.0, left.schema(), left_preserved);
-    let or_to_right =
-        extract_or_clauses_for_join(&to_keep.0, right.schema(), right_preserved);
+    let or_to_left = extract_or_clauses_for_join(
+        &keep_predicates.iter().collect::<Vec<_>>(),
+        left.schema(),
+        left_preserved,
+    );
+    let or_to_right = extract_or_clauses_for_join(
+        &keep_predicates.iter().collect::<Vec<_>>(),
+        right.schema(),
+        right_preserved,
+    );
     let on_or_to_left = extract_or_clauses_for_join(
-        &on_to_keep.iter().collect::<Vec<_>>(),
+        &keep_condition.iter().collect::<Vec<_>>(),
         left.schema(),
         left_preserved,
     );
     let on_or_to_right = extract_or_clauses_for_join(
-        &on_to_keep.iter().collect::<Vec<_>>(),
+        &keep_condition.iter().collect::<Vec<_>>(),
         right.schema(),
         right_preserved,
     );
 
-    // Build new filter states using pushable predicates
-    // from current optimizer states and from ON clause.
-    // Then recursively call optimization for both join inputs
-    let mut left_state = State::default();
-    left_state.append_predicates(to_left);
-    left_state.append_predicates(on_to_left);
-    or_to_left
-        .0
-        .into_iter()
-        .zip(or_to_left.1)
-        .for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
-    on_or_to_left
-        .0
-        .into_iter()
-        .zip(on_or_to_left.1)
-        .for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
-    let left = optimize(left, left_state)?;
-
-    let mut right_state = State::default();
-    right_state.append_predicates(to_right);
-    right_state.append_predicates(on_to_right);
-    or_to_right
-        .0
-        .into_iter()
-        .zip(or_to_right.1)
-        .for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
-    on_or_to_right
-        .0
-        .into_iter()
-        .zip(on_or_to_right.1)
-        .for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
-    let right = optimize(right, right_state)?;
+    left_push.extend(or_to_left);
+    left_push.extend(on_or_to_left);
+    right_push.extend(or_to_right);
+    right_push.extend(on_or_to_right);
 
+    let left = match conjunction(left_push) {
+        Some(predicate) => {
+            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?)
+        }
+        None => left.clone(),
+    };
+    let right = match conjunction(right_push) {
+        Some(predicate) => {
+            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?)
+        }
+        None => right.clone(),
+    };
     // Create a new Join with the new `left` and `right`
     //
     // expressions() output for Join is a vector consisting of
@@ -500,302 +386,359 @@ fn optimize_join(
     //      vector will contain only join keys (without additional
     //      element representing filter).
     let expr = plan.expressions();
-    let expr = if !on_filter.is_empty() && on_to_keep.is_empty() {
+    let expr = if !on_filter_empty && keep_condition.is_empty() {
         // New filter expression is None - should remove last element
         expr[..expr.len() - 1].to_vec()
-    } else if !on_to_keep.is_empty() {
+    } else if !keep_condition.is_empty() {
         // Replace last element with new filter expression
         expr[..expr.len() - 1]
             .iter()
             .cloned()
-            .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap()))
+            .chain(once(keep_condition.into_iter().reduce(Expr::and).unwrap()))
             .collect()
     } else {
         plan.expressions()
     };
     let plan = from_plan(plan, &expr, &[left, right])?;
 
-    if to_keep.0.is_empty() {
+    if keep_predicates.is_empty() {
         Ok(plan)
     } else {
         // wrap the join on the filter whose predicates must be kept
-        let plan = utils::add_filter(plan, &to_keep.0)?;
-        state.filters = remove_filters(&state.filters, &to_keep.1);
-
-        Ok(plan)
+        match conjunction(keep_predicates) {
+            Some(predicate) => Ok(LogicalPlan::Filter(Filter::try_new(
+                predicate,
+                Arc::new(plan),
+            )?)),
+            None => Ok(plan),
+        }
     }
 }
 
-fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
-    match plan {
-        LogicalPlan::Explain { .. } => {
-            // push the optimization to the plan of this explain
-            push_down(&state, plan)
-        }
-        LogicalPlan::Analyze { .. } => push_down(&state, plan),
-        LogicalPlan::Filter(filter) => {
-            let predicate = utils::cnf_rewrite(filter.predicate().clone());
-
-            utils::split_conjunction_owned(predicate)
-                .into_iter()
-                .try_for_each::<_, Result<()>>(|predicate| {
-                    let columns = predicate.to_columns()?;
-                    state.filters.push((predicate, columns));
-                    Ok(())
-                })?;
-
-            optimize(filter.input(), state)
+fn push_down_join(
+    plan: &LogicalPlan,
+    join: &Join,
+    parent_predicate: Option<&Expr>,
+) -> Result<Option<LogicalPlan>> {
+    let mut predicates = match parent_predicate {
+        Some(parent_predicate) => {
+            utils::split_conjunction_owned(utils::cnf_rewrite(parent_predicate.clone()))
         }
-        LogicalPlan::Projection(Projection {
-            input,
-            expr,
-            schema,
-        }) => {
-            // A projection is filter-commutable, but re-writes all predicate expressions
-            // collect projection.
-            let projection = schema
-                .fields()
-                .iter()
-                .enumerate()
-                .flat_map(|(i, field)| {
-                    // strip alias, as they should not be part of filters
-                    let expr = match &expr[i] {
-                        Expr::Alias(expr, _) => expr.as_ref().clone(),
-                        expr => expr.clone(),
+        None => vec![],
+    };
+
+    // Convert JOIN ON predicate to Predicates
+    let on_filters = join
+        .filter
+        .as_ref()
+        .map(|e| utils::split_conjunction_owned(e.clone()))
+        .unwrap_or_else(Vec::new);
+
+    if join.join_type == JoinType::Inner {
+        // For inner joins, duplicate filters for joined columns so filters can be pushed down
+        // to both sides. Take the following query as an example:
+        //
+        // ```sql
+        // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+        // ```
+        //
+        // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+        // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+        //
+        // Join clauses with `Using` constraints also take advantage of this logic to make sure
+        // predicates referencing the shared join columns are pushed to both sides.
+        // This logic should also be applied to conditions in the JOIN ON clause.
+        let join_side_filters = predicates
+            .iter()
+            .chain(on_filters.iter())
+            .filter_map(|predicate| {
+                let mut join_cols_to_replace = HashMap::new();
+                let columns = match predicate.to_columns() {
+                    Ok(columns) => columns,
+                    Err(e) => return Some(Err(e)),
+                };
+
+                for col in columns.iter() {
+                    for (l, r) in join.on.iter() {
+                        if col == l {
+                            join_cols_to_replace.insert(col, r);
+                            break;
+                        } else if col == r {
+                            join_cols_to_replace.insert(col, l);
+                            break;
+                        }
+                    }
+                }
+
+                if join_cols_to_replace.is_empty() {
+                    return None;
+                }
+
+                let join_side_predicate =
+                    match replace_col(predicate.clone(), &join_cols_to_replace) {
+                        Ok(p) => p,
+                        Err(e) => {
+                            return Some(Err(e));
+                        }
                     };
 
-                    // Convert both qualified and unqualified fields
-                    [
-                        (field.name().clone(), expr.clone()),
-                        (field.qualified_name(), expr),
-                    ]
-                })
-                .collect::<HashMap<_, _>>();
+                Some(Ok(join_side_predicate))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        predicates.extend(join_side_filters);
+    }
+    if on_filters.is_empty() && predicates.is_empty() {
+        return Ok(None);
+    }
+    Ok(Some(push_down_all_join(
+        predicates,
+        plan,
+        &join.left,
+        &join.right,
+        on_filters,
+    )?))
+}
 
-            // re-write all filters based on this projection
-            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
-            for (predicate, columns) in state.filters.iter_mut() {
-                *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
+impl OptimizerRule for PushDownFilter {
+    fn name(&self) -> &str {
+        "push_down_filter"
+    }
 
-                columns.clear();
-                expr_to_columns(predicate, columns)?;
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        let filter = match plan {
+            LogicalPlan::Filter(filter) => filter,
+            // we also need to pushdown filter in Join.
+            LogicalPlan::Join(join) => {
+                let optimized_plan = push_down_join(plan, join, None)?;
+                return match optimized_plan {
+                    Some(optimized_plan) => {
+                        utils::optimize_children(self, &optimized_plan, optimizer_config)
+                    }
+                    None => utils::optimize_children(self, plan, optimizer_config),
+                };
             }
+            _ => return utils::optimize_children(self, plan, optimizer_config),
+        };
 
-            // optimize inner
-            let new_input = optimize(input, state)?;
-            Ok(from_plan(plan, expr, &[new_input])?)
-        }
-        LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
-            // An aggregate's aggreagate columns are _not_ filter-commutable => collect these:
-            // * columns whose aggregation expression depends on
-            // * the aggregation columns themselves
-
-            // construct set of columns that `aggr_expr` depends on
-            let mut used_columns = HashSet::new();
-            exprlist_to_columns(aggr_expr, &mut used_columns)?;
-
-            let agg_columns = aggr_expr
-                .iter()
-                .map(|x| Ok(Column::from_name(x.display_name()?)))
-                .collect::<Result<HashSet<_>>>()?;
-            used_columns.extend(agg_columns);
-
-            issue_filters(state, used_columns, plan)
-        }
-        LogicalPlan::Sort { .. } => {
-            // sort is filter-commutable
-            push_down(&state, plan)
-        }
-        LogicalPlan::Union(Union { inputs: _, schema }) => {
-            // union changing all qualifiers while building logical plan so we need
-            // to rewrite filters to push unqualified columns to inputs
-            let projection = schema
-                .fields()
-                .iter()
-                .map(|field| (field.qualified_name(), col(field.name())))
-                .collect::<HashMap<_, _>>();
-
-            // rewriting predicate expressions using unqualified names as replacements
-            if !projection.is_empty() {
-                for (predicate, columns) in state.filters.iter_mut() {
-                    *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
-
-                    columns.clear();
-                    expr_to_columns(predicate, columns)?;
+        let child_plan = &**filter.input();
+        let new_plan = match child_plan {
+            LogicalPlan::Filter(child_filter) => {
+                let new_predicate =
+                    and(filter.predicate().clone(), child_filter.predicate().clone());
+                let new_plan = LogicalPlan::Filter(Filter::try_new(
+                    new_predicate,
+                    child_filter.input().clone(),
+                )?);
+                return self.optimize(&new_plan, optimizer_config);
+            }
+            LogicalPlan::Repartition(_)
+            | LogicalPlan::Distinct(_)
+            | LogicalPlan::Sort(_) => {
+                // filter-commutable nodes: swap the filter below the child
+                let new_filter =
+                    plan.with_new_inputs(&[
+                        (**(child_plan.inputs().get(0).unwrap())).clone()
+                    ])?;
+                child_plan.with_new_inputs(&[new_filter])?
+            }
+            LogicalPlan::SubqueryAlias(subquery_alias) => {
+                let mut replace_map = HashMap::new();
+                for (i, field) in
+                    subquery_alias.input.schema().fields().iter().enumerate()
+                {
+                    replace_map.insert(
+                        subquery_alias
+                            .schema
+                            .fields()
+                            .get(i)
+                            .unwrap()
+                            .qualified_name(),
+                        Expr::Column(field.qualified_column()),
+                    );
                 }
+                let new_predicate =
+                    replace_cols_by_name(filter.predicate().clone(), &replace_map)?;
+                let new_filter = LogicalPlan::Filter(Filter::try_new(
+                    new_predicate,
+                    subquery_alias.input.clone(),
+                )?);
+                child_plan.with_new_inputs(&[new_filter])?
             }
-
-            push_down(&state, plan)
-        }
-        LogicalPlan::Limit(Limit { input, .. }) => {
-            // limit is _not_ filter-commutable => collect all columns from its input
-            let used_columns = input
-                .schema()
-                .fields()
-                .iter()
-                .map(|f| f.qualified_column())
-                .collect::<HashSet<_>>();
-            issue_filters(state, used_columns, plan)
-        }
-        LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
-            optimize_join(state, plan, left, right, vec![])
-        }
-        LogicalPlan::Join(Join {
-            left,
-            right,
-            on,
-            filter,
-            join_type,
-            ..
-        }) => {
-            // Convert JOIN ON predicate to Predicates
-            let on_filters = filter
-                .as_ref()
-                .map(|e| {
-                    let predicates = utils::split_conjunction(e);
-
-                    predicates
-                        .into_iter()
-                        .map(|e| Ok((e.clone(), e.to_columns()?)))
-                        .collect::<Result<Vec<_>>>()
-                })
-                .unwrap_or_else(|| Ok(vec![]))?;
-
-            if *join_type == JoinType::Inner {
-                // For inner joins, duplicate filters for joined columns so filters can be pushed down
-                // to both sides. Take the following query as an example:
-                //
-                // ```sql
-                // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
-                // ```
-                //
-                // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
-                // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
-                //
-                // Join clauses with `Using` constraints also take advantage of this logic to make sure
-                // predicates reference the shared join columns are pushed to both sides.
-                // This logic should also been applied to conditions in JOIN ON clause
-                let join_side_filters = state
-                    .filters
+            LogicalPlan::Projection(projection) => {
+                // A projection is filter-commutable, but the predicate must be rewritten:
+                // map each output column name to the un-aliased expression that produces it.
+                let replace_map = projection
+                    .schema
+                    .fields()
                     .iter()
-                    .chain(on_filters.iter())
-                    .filter_map(|(predicate, columns)| {
-                        let mut join_cols_to_replace = HashMap::new();
-                        for col in columns.iter() {
-                            for (l, r) in on {
-                                if col == l {
-                                    join_cols_to_replace.insert(col, r);
-                                    break;
-                                } else if col == r {
-                                    join_cols_to_replace.insert(col, l);
-                                    break;
-                                }
-                            }
-                        }
+                    .enumerate()
+                    .map(|(i, field)| {
+                        // strip the alias, as aliases should not be part of filters
+                        let expr = match &projection.expr[i] {
+                            Expr::Alias(expr, _) => expr.as_ref().clone(),
+                            expr => expr.clone(),
+                        };
+
+                        (field.qualified_name(), expr)
+                    })
+                    .collect::<HashMap<_, _>>();
 
-                        if join_cols_to_replace.is_empty() {
-                            return None;
-                        }
+                // re-write all filters based on this projection
+                // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
+                let new_filter = LogicalPlan::Filter(Filter::try_new(
+                    replace_cols_by_name(filter.predicate().clone(), &replace_map)?,
+                    projection.input.clone(),
+                )?);
 
-                        let join_side_predicate =
-                            match replace_col(predicate.clone(), &join_cols_to_replace) {
-                                Ok(p) => p,
-                                Err(e) => {
-                                    return Some(Err(e));
-                                }
-                            };
-
-                        let join_side_columns = columns
-                            .clone()
-                            .into_iter()
-                            // replace keys in join_cols_to_replace with values in resulting column
-                            // set
-                            .filter(|c| !join_cols_to_replace.contains_key(c))
-                            .chain(join_cols_to_replace.values().map(|v| (*v).clone()))
-                            .collect();
-
-                        Some(Ok((join_side_predicate, join_side_columns)))
-                    })
-                    .collect::<Result<Vec<_>>>()?;
-                state.filters.extend(join_side_filters);
+                child_plan.with_new_inputs(&[new_filter])?
             }
+            LogicalPlan::Union(union) => {
+                let mut inputs = Vec::with_capacity(union.inputs.len());
+                for input in &union.inputs {
+                    let mut replace_map = HashMap::new();
+                    for (i, field) in input.schema().fields().iter().enumerate() {
+                        replace_map.insert(
+                            union.schema.fields().get(i).unwrap().qualified_name(),
+                            Expr::Column(field.qualified_column()),
+                        );
+                    }
 
-            optimize_join(state, plan, left, right, on_filters)
-        }
-        LogicalPlan::TableScan(TableScan {
-            source,
-            projected_schema,
-            filters,
-            projection,
-            table_name,
-            fetch,
-        }) => {
-            let mut used_columns = HashSet::new();
-            let mut new_filters = filters.clone();
-
-            for (filter_expr, cols) in &state.filters {
-                let (preserve_filter_node, add_to_provider) =
-                    match source.supports_filter_pushdown(filter_expr)? {
-                        TableProviderFilterPushDown::Unsupported => (true, false),
-                        TableProviderFilterPushDown::Inexact => (true, true),
-                        TableProviderFilterPushDown::Exact => (false, true),
-                    };
-
-                if preserve_filter_node {
-                    used_columns.extend(cols.clone());
+                    let push_predicate =
+                        replace_cols_by_name(filter.predicate().clone(), &replace_map)?;
+                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
+                        push_predicate,
+                        input.clone(),
+                    )?)))
                 }
-
-                if add_to_provider {
-                    // Don't add expression again if it's already present in
-                    // pushed down filters.
-                    if new_filters.contains(filter_expr) {
-                        continue;
+                LogicalPlan::Union(Union {
+                    inputs,
+                    schema: plan.schema().clone(),
+                })
+            }
+            LogicalPlan::Aggregate(agg) => {
+                // Filters on an aggregate's aggregate columns are _not_ commutable => collect:
+                // * the columns that the aggregation expressions depend on
+                // * the aggregation columns themselves
+
+                // construct set of columns that `aggr_expr` depends on
+                let mut used_columns = HashSet::new();
+                exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
+                let agg_columns = agg
+                    .aggr_expr
+                    .iter()
+                    .map(|x| Ok(Column::from_name(x.display_name()?)))
+                    .collect::<Result<HashSet<_>>>()?;
+                used_columns.extend(agg_columns);
+
+                let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
+                    filter.predicate().clone(),
+                ));
+
+                let mut keep_predicates = vec![];
+                let mut push_predicates = vec![];
+                for expr in predicates {
+                    let columns = expr.to_columns()?;
+                    if columns.is_empty()
+                        || !columns
+                            .intersection(&used_columns)
+                            .collect::<HashSet<_>>()
+                            .is_empty()
+                    {
+                        keep_predicates.push(expr);
+                    } else {
+                        push_predicates.push(expr);
                     }
-                    new_filters.push(filter_expr.clone());
+                }
+
+                let child = match conjunction(push_predicates) {
+                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        Arc::new((*agg.input).clone()),
+                    )?),
+                    None => (*agg.input).clone(),
+                };
+                let new_agg = from_plan(
+                    filter.input(),
+                    &filter.input().expressions(),
+                    &vec![child],
+                )?;
+                match conjunction(keep_predicates) {
+                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        Arc::new(new_agg),
+                    )?),
+                    None => new_agg,
+                }
+            }
+            LogicalPlan::Join(join) => {
+                match push_down_join(filter.input(), join, Some(filter.predicate()))? {
+                    Some(optimized_plan) => optimized_plan,
+                    None => plan.clone(),
                 }
             }
+            LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
+                let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
+                    filter.predicate().clone(),
+                ));
 
-            issue_filters(
-                state,
-                used_columns,
-                &LogicalPlan::TableScan(TableScan {
-                    source: source.clone(),
-                    projection: projection.clone(),
-                    projected_schema: projected_schema.clone(),
-                    table_name: table_name.clone(),
-                    filters: new_filters,
-                    fetch: *fetch,
-                }),
-            )
-        }
-        _ => {
-            // all other plans are _not_ filter-commutable
-            let used_columns = plan
-                .schema()
-                .fields()
-                .iter()
-                .map(|f| f.qualified_column())
-                .collect::<HashSet<_>>();
-            issue_filters(state, used_columns, plan)
-        }
-    }
-}
+                push_down_all_join(predicates, filter.input(), left, right, vec![])?
+            }
+            LogicalPlan::TableScan(scan) => {
+                let mut new_scan_filters = scan.filters.clone();
+                let mut new_predicate = vec![];
+
+                let filter_predicates = utils::split_conjunction_owned(
+                    utils::cnf_rewrite(filter.predicate().clone()),
+                );
+
+                for filter_expr in &filter_predicates {
+                    let (preserve_filter_node, add_to_provider) =
+                        match scan.source.supports_filter_pushdown(filter_expr)? {
+                            TableProviderFilterPushDown::Unsupported => (true, false),
+                            TableProviderFilterPushDown::Inexact => (true, true),
+                            TableProviderFilterPushDown::Exact => (false, true),
+                        };
+                    if preserve_filter_node {
+                        new_predicate.push(filter_expr.clone());
+                    }
+                    if add_to_provider {
+                        // skip filter exprs the scan already carries, to avoid duplicates.
+                        if new_scan_filters.contains(filter_expr) {
+                            continue;
+                        }
+                        new_scan_filters.push(filter_expr.clone());
+                    }
+                }
 
-impl OptimizerRule for FilterPushDown {
-    fn name(&self) -> &str {
-        "filter_push_down"
-    }
+                let new_scan = LogicalPlan::TableScan(TableScan {
+                    source: scan.source.clone(),
+                    projection: scan.projection.clone(),
+                    projected_schema: scan.projected_schema.clone(),
+                    table_name: scan.table_name.clone(),
+                    filters: new_scan_filters,
+                    fetch: scan.fetch,
+                });
+
+                match conjunction(new_predicate) {
+                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        Arc::new(new_scan),
+                    )?),
+                    None => new_scan,
+                }
+            }
+            _ => plan.clone(),
+        };
 
-    fn optimize(
-        &self,
-        plan: &LogicalPlan,
-        _: &mut OptimizerConfig,
-    ) -> Result<LogicalPlan> {
-        optimize(plan, State::default())
+        utils::optimize_children(self, &new_plan, optimizer_config)
     }
 }
 
-impl FilterPushDown {
+impl PushDownFilter {
     #[allow(missing_docs)]
     pub fn new() -> Self {
         Self {}
@@ -828,25 +771,24 @@ fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<
 mod tests {
     use super::*;
     use crate::test::*;
-    use arrow::datatypes::SchemaRef;
+    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use async_trait::async_trait;
     use datafusion_common::DFSchema;
+    use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
-        and, col, in_list, in_subquery, lit, logical_plan::JoinType, sum, Expr,
-        LogicalPlanBuilder, Operator, TableSource, TableType,
+        and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
+        Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
     };
     use std::sync::Arc;
 
-    fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan {
-        let rule = FilterPushDown::new();
-        rule.optimize(plan, &mut OptimizerConfig::new())
-            .expect("failed to optimize plan")
-    }
-
-    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
-        let optimized_plan = optimize_plan(plan);
+    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+        let optimized_plan = PushDownFilter::new()
+            .optimize(plan, &mut OptimizerConfig::new())
+            .expect("failed to optimize plan");
         let formatted_plan = format!("{:?}", optimized_plan);
-        assert_eq!(formatted_plan, expected);
+        assert_eq!(plan.schema(), optimized_plan.schema());
+        assert_eq!(expected, formatted_plan);
+        Ok(())
     }
 
     #[test]
@@ -861,8 +803,7 @@ mod tests {
             Projection: test.a, test.b\
             \n  Filter: test.a = Int64(1)\
             \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -879,8 +820,7 @@ mod tests {
             \n  Limit: skip=0, fetch=10\
             \n    Projection: test.a, test.b\
             \n      TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -892,8 +832,7 @@ mod tests {
         let expected = "\
             Filter: Int64(0) = Int64(1)\
             \n  TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -910,8 +849,7 @@ mod tests {
             \n  Projection: test.a, test.b, test.c\
             \n    Filter: test.a = Int64(1)\
             \n      TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -926,8 +864,20 @@ mod tests {
             Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\
             \n  Filter: test.a > Int64(10)\
             \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn filter_complex_group_by() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
+            .filter(col("b").gt(lit(10i64)))?
+            .build()?;
+        let expected = "Filter: test.b > Int64(10)\
+        \n  Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
+        \n    TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -942,8 +892,7 @@ mod tests {
             Filter: b > Int64(10)\
             \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
             \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -966,8 +915,7 @@ mod tests {
         \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
         \n    Filter: test.c = Int64(1) OR test.c = Int64(1)\
         \n      TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
@@ -983,8 +931,7 @@ mod tests {
             Projection: test.a AS b, test.c\
             \n  Filter: test.a = Int64(1)\
             \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     fn add(left: Expr, right: Expr) -> Expr {
@@ -1029,8 +976,7 @@ mod tests {
             Projection: test.a * Int32(2) + test.c AS b, test.c\
             \n  Filter: test.a * Int32(2) + test.c = Int64(1)\
             \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
@@ -1063,8 +1009,7 @@ mod tests {
         \n  Projection: test.a * Int32(2) + test.c AS b, test.c\
         \n    Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
         \n      TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
@@ -1098,9 +1043,7 @@ mod tests {
         \n    Projection: test.a AS b, test.c\
         \n      Filter: test.a > Int64(10)\
         \n        TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
@@ -1135,9 +1078,7 @@ mod tests {
         \n    Projection: test.a AS b, test.c\
         \n      Filter: test.a > Int64(10)\
         \n        TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that when two limits are in place, we jump neither
@@ -1159,26 +1100,24 @@ mod tests {
             \n      Limit: skip=0, fetch=20\
             \n        Projection: test.a, test.b\
             \n          TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
     fn union_all() -> Result<()> {
         let table_scan = test_table_scan()?;
-        let plan = LogicalPlanBuilder::from(table_scan.clone())
-            .union(LogicalPlanBuilder::from(table_scan).build()?)?
+        let table_scan2 = test_table_scan_with_name("test2")?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
             .filter(col("a").eq(lit(1i64)))?
             .build()?;
         // filter appears below Union
-        let expected = "\
-            Union\
-            \n  Filter: a = Int64(1)\
-            \n    TableScan: test\
-            \n  Filter: a = Int64(1)\
-            \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        let expected = "Union\
+        \n  Filter: test.a = Int64(1)\
+        \n    TableScan: test\
+        \n  Filter: test2.a = Int64(1)\
+        \n    TableScan: test2";
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -1193,8 +1132,7 @@ mod tests {
             .build()?;
 
         // filter appears below Union
-        let expected = "Union\
-        \n  SubqueryAlias: test2\
+        let expected = "Union\n  SubqueryAlias: test2\
         \n    Projection: test.a AS b\
         \n      Filter: test.a = Int64(1)\
         \n        TableScan: test\
@@ -1202,8 +1140,68 @@ mod tests {
         \n    Projection: test.a AS b\
         \n      Filter: test.a = Int64(1)\
         \n        TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_union_different_schema() -> Result<()> {
+        let left = LogicalPlanBuilder::from(test_table_scan()?)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .build()?;
+
+        let schema = Schema::new(vec![
+            Field::new("d", DataType::UInt32, false),
+            Field::new("e", DataType::UInt32, false),
+            Field::new("f", DataType::UInt32, false),
+        ]);
+        let right = table_scan(Some("test1"), &schema, None)?
+            .project(vec![col("d"), col("e"), col("f")])?
+            .build()?;
+        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
+        let plan = LogicalPlanBuilder::from(left)
+            .cross_join(&right)?
+            .project(vec![col("test.a"), col("test1.d")])?
+            .filter(filter)?
+            .build()?;
+
+        let expected = "Projection: test.a, test1.d\
+        \n  CrossJoin:\
+        \n    Projection: test.a, test.b, test.c\
+        \n      Filter: test.a = Int32(1)\
+        \n        TableScan: test\
+        \n    Projection: test1.d, test1.e, test1.f\
+        \n      Filter: test1.d > Int32(2)\
+        \n        TableScan: test1";
+
+        assert_optimized_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn test_project_same_name_different_qualifier() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .build()?;
+        let right_table_scan = test_table_scan_with_name("test1")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .build()?;
+        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
+        let plan = LogicalPlanBuilder::from(left)
+            .cross_join(&right)?
+            .project(vec![col("test.a"), col("test1.a")])?
+            .filter(filter)?
+            .build()?;
+
+        let expected = "Projection: test.a, test1.a\
+        \n  CrossJoin:\
+        \n    Projection: test.a, test.b, test.c\
+        \n      Filter: test.a = Int32(1)\
+        \n        TableScan: test\
+        \n    Projection: test1.a, test1.b, test1.c\
+        \n      Filter: test1.a > Int32(2)\
+        \n        TableScan: test1";
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that filters with the same columns are correctly placed
@@ -1238,8 +1236,7 @@ mod tests {
         \n        Filter: test.a <= Int64(1)\
         \n          TableScan: test";
 
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that filters to be placed on the same depth are ANDed
@@ -1269,8 +1266,7 @@ mod tests {
         \n    Limit: skip=0, fetch=1\
         \n      TableScan: test";
 
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// verifies that filters on a plan with user nodes are not lost
@@ -1292,8 +1288,7 @@ mod tests {
         // not part of the test
         assert_eq!(format!("{:?}", plan), expected);
 
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-on-join predicates on a column common to both sides is pushed to both sides
@@ -1318,8 +1313,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test.a <= Int64(1)\
+            "Filter: test.a <= Int64(1)\
             \n  Inner Join: test.a = test2.a\
             \n    TableScan: test\
             \n    Projection: test2.a\
@@ -1334,8 +1328,7 @@ mod tests {
         \n  Projection: test2.a\
         \n    Filter: test2.a <= Int64(1)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-using-join predicates on a column common to both sides is pushed to both sides
@@ -1359,8 +1352,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test.a <= Int64(1)\
+            "Filter: test.a <= Int64(1)\
             \n  Inner Join: Using test.a = test2.a\
             \n    TableScan: test\
             \n    Projection: test2.a\
@@ -1375,8 +1367,7 @@ mod tests {
         \n  Projection: test2.a\
         \n    Filter: test2.a <= Int64(1)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-join predicates with columns from both sides are not pushed
@@ -1404,8 +1395,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test.c <= test2.b\
+            "Filter: test.c <= test2.b\
             \n  Inner Join: test.a = test2.a\
             \n    Projection: test.a, test.c\
             \n      TableScan: test\
@@ -1415,8 +1405,7 @@ mod tests {
 
         // expected is equal: no push-down
         let expected = &format!("{:?}", plan);
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-join predicates with columns from one side of a join are pushed only to that side
@@ -1444,8 +1433,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test.b <= Int64(1)\
+            "Filter: test.b <= Int64(1)\
             \n  Inner Join: test.a = test2.a\
             \n    Projection: test.a, test.b\
             \n      TableScan: test\
@@ -1460,8 +1448,7 @@ mod tests {
         \n      TableScan: test\
         \n  Projection: test2.a, test2.c\
         \n    TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-join predicates on the right side of a left join are not duplicated
@@ -1486,8 +1473,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test2.a <= Int64(1)\
+            "Filter: test2.a <= Int64(1)\
             \n  Left Join: Using test.a = test2.a\
             \n    TableScan: test\
             \n    Projection: test2.a\
@@ -1501,12 +1487,10 @@ mod tests {
         \n    TableScan: test\
         \n    Projection: test2.a\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-join predicates on the left side of a right join are not duplicated
-    /// TODO: In this case we can sometimes convert the join to an INNER join
     #[test]
     fn filter_using_right_join() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -1527,8 +1511,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test.a <= Int64(1)\
+            "Filter: test.a <= Int64(1)\
             \n  Right Join: Using test.a = test2.a\
             \n    TableScan: test\
             \n    Projection: test2.a\
@@ -1542,8 +1525,7 @@ mod tests {
         \n    TableScan: test\
         \n    Projection: test2.a\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-left-join predicate on a column common to both sides is only pushed to the left side
@@ -1568,8 +1550,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test.a <= Int64(1)\
+            "Filter: test.a <= Int64(1)\
             \n  Left Join: Using test.a = test2.a\
             \n    TableScan: test\
             \n    Projection: test2.a\
@@ -1583,8 +1564,7 @@ mod tests {
         \n    TableScan: test\
         \n  Projection: test2.a\
         \n    TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// post-right-join predicate on a column common to both sides is only pushed to the right side
@@ -1609,8 +1589,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: test2.a <= Int64(1)\
+            "Filter: test2.a <= Int64(1)\
             \n  Right Join: Using test.a = test2.a\
             \n    TableScan: test\
             \n    Projection: test2.a\
@@ -1624,8 +1603,7 @@ mod tests {
         \n  Projection: test2.a\
         \n    Filter: test2.a <= Int64(1)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// single table predicate parts of ON condition should be pushed to both inputs
@@ -1655,8 +1633,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+            "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
             \n  Projection: test.a, test.b, test.c\
             \n    TableScan: test\
             \n  Projection: test2.a, test2.b, test2.c\
@@ -1671,8 +1648,7 @@ mod tests {
         \n  Projection: test2.a, test2.b, test2.c\
         \n    Filter: test2.c > UInt32(4)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// join filter should be completely removed after pushdown
@@ -1701,8 +1677,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
+            "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
             \n  Projection: test.a, test.b, test.c\
             \n    TableScan: test\
             \n  Projection: test2.a, test2.b, test2.c\
@@ -1717,8 +1692,7 @@ mod tests {
         \n  Projection: test2.a, test2.b, test2.c\
         \n    Filter: test2.c > UInt32(4)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// predicate on join key in filter expression should be pushed down to both inputs
@@ -1745,8 +1719,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
+            "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
             \n  Projection: test.a\
             \n    TableScan: test\
             \n  Projection: test2.b\
@@ -1761,8 +1734,7 @@ mod tests {
         \n  Projection: test2.b\
         \n    Filter: test2.b > UInt32(1)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// single table predicate parts of ON condition should be pushed to right input
@@ -1792,8 +1764,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+            "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
             \n  Projection: test.a, test.b, test.c\
             \n    TableScan: test\
             \n  Projection: test2.a, test2.b, test2.c\
@@ -1807,8 +1778,7 @@ mod tests {
         \n  Projection: test2.a, test2.b, test2.c\
         \n    Filter: test2.c > UInt32(4)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// single table predicate parts of ON condition should be pushed to left input
@@ -1838,8 +1808,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+            "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
             \n  Projection: test.a, test.b, test.c\
             \n    TableScan: test\
             \n  Projection: test2.a, test2.b, test2.c\
@@ -1853,8 +1822,7 @@ mod tests {
         \n      TableScan: test\
         \n  Projection: test2.a, test2.b, test2.c\
         \n    TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// single table predicate parts of ON condition should not be pushed
@@ -1884,8 +1852,7 @@ mod tests {
         // not part of the test, just good to know:
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+            "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
             \n  Projection: test.a, test.b, test.c\
             \n    TableScan: test\
             \n  Projection: test2.a, test2.b, test2.c\
@@ -1893,8 +1860,7 @@ mod tests {
         );
 
         let expected = &format!("{:?}", plan);
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     struct PushDownProvider {
@@ -1961,8 +1927,7 @@ mod tests {
 
         let expected = "\
         TableScan: test, full_filters=[a = Int64(1)]";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -1973,8 +1938,7 @@ mod tests {
         let expected = "\
         Filter: a = Int64(1)\
         \n  TableScan: test, partial_filters=[a = Int64(1)]";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -1982,7 +1946,9 @@ mod tests {
         let plan =
             table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
 
-        let optimised_plan = optimize_plan(&plan);
+        let optimised_plan = PushDownFilter::new()
+            .optimize(&plan, &mut OptimizerConfig::new())
+            .expect("failed to optimize plan");
 
         let expected = "\
         Filter: a = Int64(1)\
@@ -1990,8 +1956,7 @@ mod tests {
 
         // Optimizing the same plan multiple times should produce the same plan
         // each time.
-        assert_optimized_plan_eq(&optimised_plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&optimised_plan, expected)
     }
 
     #[test]
@@ -2002,8 +1967,7 @@ mod tests {
         let expected = "\
         Filter: a = Int64(1)\
         \n  TableScan: test";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2028,13 +1992,11 @@ mod tests {
             .project(vec![col("a"), col("b")])?
             .build()?;
 
-        let expected ="Projection: a, b\
+        let expected = "Projection: a, b\
             \n  Filter: a = Int64(10) AND b > Int64(11)\
             \n    TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]";
 
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2051,11 +2013,9 @@ mod tests {
         // filter on col b
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: b > Int64(10) AND test.c > Int64(10)\
+            "Filter: b > Int64(10) AND test.c > Int64(10)\
             \n  Projection: test.a AS b, test.c\
-            \n    TableScan: test\
-            "
+            \n    TableScan: test"
         );
 
         // rewrite filter col b to test.a
@@ -2065,9 +2025,7 @@ mod tests {
             \n    TableScan: test\
             ";
 
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2085,8 +2043,7 @@ mod tests {
         // filter on col b
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: b > Int64(10) AND test.c > Int64(10)\
+            "Filter: b > Int64(10) AND test.c > Int64(10)\
             \n  Projection: b, test.c\
             \n    Projection: test.a AS b, test.c\
             \n      TableScan: test\
@@ -2101,9 +2058,7 @@ mod tests {
             \n      TableScan: test\
             ";
 
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2117,8 +2072,7 @@ mod tests {
         // filter on col b and d
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: b > Int64(10) AND d > Int64(10)\
+            "Filter: b > Int64(10) AND d > Int64(10)\
             \n  Projection: test.a AS b, test.c AS d\
             \n    TableScan: test\
             "
@@ -2131,9 +2085,7 @@ mod tests {
             \n    TableScan: test\
             ";
 
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     /// predicate on join key in filter expression should be pushed down to both inputs
@@ -2159,8 +2111,7 @@ mod tests {
 
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Inner Join: c = d Filter: c > UInt32(1)\
+            "Inner Join: c = d Filter: c > UInt32(1)\
             \n  Projection: test.a AS c\
             \n    TableScan: test\
             \n  Projection: test2.b AS d\
@@ -2176,8 +2127,7 @@ mod tests {
         \n  Projection: test2.b AS d\
         \n    Filter: test2.b > UInt32(1)\
         \n      TableScan: test2";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2195,8 +2145,7 @@ mod tests {
         // filter on col b
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
+            "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
             \n  Projection: test.a AS b, test.c\
             \n    TableScan: test\
             "
@@ -2209,9 +2158,7 @@ mod tests {
             \n    TableScan: test\
             ";
 
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2230,8 +2177,7 @@ mod tests {
         // filter on col b
         assert_eq!(
             format!("{:?}", plan),
-            "\
-            Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
+            "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
             \n  Projection: b, test.c\
             \n    Projection: test.a AS b, test.c\
             \n      TableScan: test\
@@ -2246,9 +2192,7 @@ mod tests {
             \n      TableScan: test\
             ";
 
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
@@ -2285,9 +2229,7 @@ mod tests {
         \n      Projection: sq.c\
         \n        TableScan: sq\
         \n    TableScan: test";
-        assert_optimized_plan_eq(&plan, expected_after);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected_after)
     }
 
     #[test]
@@ -2312,15 +2254,13 @@ mod tests {
         // Ensure that the predicate without any columns (0 = 1) is
         // still there.
         let expected_after = "Projection: b.a\
-        \n  Filter: b.a = Int64(1)\
-        \n    SubqueryAlias: b\
-        \n      Projection: b.a\
-        \n        SubqueryAlias: b\
-        \n          Projection: Int64(0) AS a\
+        \n  SubqueryAlias: b\
+        \n    Projection: b.a\
+        \n      SubqueryAlias: b\
+        \n        Projection: Int64(0) AS a\
+        \n          Filter: Int64(0) = Int64(1)\
         \n            EmptyRelation";
-        assert_optimized_plan_eq(&plan, expected_after);
-
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected_after)
     }
 
     #[test]
@@ -2351,7 +2291,13 @@ mod tests {
         \n        TableScan: test\
         \n    Projection: test1.a AS d, test1.a AS e\
         \n      TableScan: test1";
-        assert_optimized_plan_eq(&plan, expected);
-        Ok(())
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        // Previously, global state helped to avoid generating and pushing down duplicate Filters.
+        // Now that the global state is removed, double-check that optimizing again adds no duplicate Filters.
+        let optimized_plan = PushDownFilter::new()
+            .optimize(&plan, &mut OptimizerConfig::new())
+            .expect("failed to optimize plan");
+        assert_optimized_plan_eq(&optimized_plan, expected)
     }
 }
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index c4911439c..457ea833e 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -274,12 +274,12 @@ fn join_keys_in_subquery_alias() {
     let plan = test_sql(sql).unwrap();
     let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
     \n  Inner Join: a.col_int32 = b.key\
-    \n    Filter: a.col_int32 IS NOT NULL\
-    \n      SubqueryAlias: a\
+    \n    SubqueryAlias: a\
+    \n      Filter: test.col_int32 IS NOT NULL\
     \n        TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
-    \n    Filter: b.key IS NOT NULL\
-    \n      SubqueryAlias: b\
-    \n        Projection: test.col_int32 AS key\
+    \n    SubqueryAlias: b\
+    \n      Projection: test.col_int32 AS key\
+    \n        Filter: test.col_int32 IS NOT NULL\
     \n          TableScan: test projection=[col_int32]";
     assert_eq!(expected, format!("{:?}", plan));
 }
@@ -288,20 +288,19 @@ fn join_keys_in_subquery_alias() {
 fn join_keys_in_subquery_alias_1() {
     let sql = "SELECT * FROM test AS A, ( SELECT test.col_int32 AS key FROM test JOIN test AS C on test.col_int32 = C.col_int32 ) AS B where A.col_int32 = B.key;";
     let plan = test_sql(sql).unwrap();
-    let expected =  "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
+    let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
     \n  Inner Join: a.col_int32 = b.key\
-    \n    Filter: a.col_int32 IS NOT NULL\
-    \n      SubqueryAlias: a\
+    \n    SubqueryAlias: a\
+    \n      Filter: test.col_int32 IS NOT NULL\
     \n        TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
-    \n    Filter: b.key IS NOT NULL\
-    \n      SubqueryAlias: b\
-    \n        Projection: test.col_int32 AS key\
-    \n          Inner Join: test.col_int32 = c.col_int32\
+    \n    SubqueryAlias: b\
+    \n      Projection: test.col_int32 AS key\
+    \n        Inner Join: test.col_int32 = c.col_int32\
+    \n          Filter: test.col_int32 IS NOT NULL\
+    \n            TableScan: test projection=[col_int32]\
+    \n          SubqueryAlias: c\
     \n            Filter: test.col_int32 IS NOT NULL\
-    \n              TableScan: test projection=[col_int32]\
-    \n            Filter: c.col_int32 IS NOT NULL\
-    \n              SubqueryAlias: c\
-    \n                TableScan: test projection=[col_int32]";
+    \n              TableScan: test projection=[col_int32]";
     assert_eq!(expected, format!("{:?}", plan));
 }