You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/30 19:51:08 UTC
[arrow-datafusion] branch master updated: reimplement `push_down_filter` to remove global-state (#4365)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 3fe542f2a reimplement `push_down_filter` to remove global-state (#4365)
3fe542f2a is described below
commit 3fe542f2afcd5360edc9abb7ad1e8243b560a6b2
Author: jakevin <ja...@gmail.com>
AuthorDate: Thu Dec 1 03:51:03 2022 +0800
reimplement `push_down_filter` to remove global-state (#4365)
* reimplement `push_down_filter`
* add comment
* fix union replace
* fix bug when meeting the same name with a different qualifier, and add UT
* add UT `filter_complex_agg` `test_union_different_schema`
* fix regression when push_down_filter meets a SubqueryAlias
* polish comment
* add UT confirming that duplicate Filters are avoided
* merge confirm UT
* remove TODO
---
benchmarks/expected-plans/q21.txt | 8 +-
benchmarks/expected-plans/q7.txt | 8 +-
datafusion/core/tests/sql/joins.rs | 13 +-
datafusion/optimizer/src/lib.rs | 2 +-
datafusion/optimizer/src/optimizer.rs | 5 +-
.../{filter_push_down.rs => push_down_filter.rs} | 1196 ++++++++++----------
datafusion/optimizer/tests/integration-test.rs | 31 +-
7 files changed, 604 insertions(+), 659 deletions(-)
diff --git a/benchmarks/expected-plans/q21.txt b/benchmarks/expected-plans/q21.txt
index 397e0a8d8..3ef6269de 100644
--- a/benchmarks/expected-plans/q21.txt
+++ b/benchmarks/expected-plans/q21.txt
@@ -7,8 +7,8 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
Inner Join: l1.l_orderkey = orders.o_orderkey
Inner Join: supplier.s_suppkey = l1.l_suppkey
TableScan: supplier projection=[s_suppkey, s_name, s_nationkey]
- Filter: l1.l_receiptdate > l1.l_commitdate
- SubqueryAlias: l1
+ SubqueryAlias: l1
+ Filter: lineitem.l_receiptdate > lineitem.l_commitdate
TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate]
Filter: orders.o_orderstatus = Utf8("F")
TableScan: orders projection=[o_orderkey, o_orderstatus]
@@ -16,6 +16,6 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
TableScan: nation projection=[n_nationkey, n_name]
SubqueryAlias: l2
TableScan: lineitem projection=[l_orderkey, l_suppkey]
- Filter: l3.l_receiptdate > l3.l_commitdate
- SubqueryAlias: l3
+ SubqueryAlias: l3
+ Filter: lineitem.l_receiptdate > lineitem.l_commitdate
TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate]
\ No newline at end of file
diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt
index 74857c6f9..53deda1b8 100644
--- a/benchmarks/expected-plans/q7.txt
+++ b/benchmarks/expected-plans/q7.txt
@@ -14,9 +14,9 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST,
TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate]
TableScan: orders projection=[o_orderkey, o_custkey]
TableScan: customer projection=[c_custkey, c_nationkey]
- Filter: n1.n_name = Utf8("FRANCE") OR n1.n_name = Utf8("GERMANY")
- SubqueryAlias: n1
+ SubqueryAlias: n1
+ Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name]
- Filter: n2.n_name = Utf8("GERMANY") OR n2.n_name = Utf8("FRANCE")
- SubqueryAlias: n2
+ SubqueryAlias: n2
+ Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE")
TableScan: nation projection=[n_nationkey, n_name]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 87fb594c7..7129fc7ed 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1636,15 +1636,14 @@ async fn reduce_left_join_3() -> Result<()> {
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t3.t1_id, t3.t1_name, t3.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: t3.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ]
- ;
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index cdfe7fc9b..aba53a3d8 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -23,12 +23,12 @@ pub mod eliminate_filter;
pub mod eliminate_limit;
pub mod eliminate_outer_join;
pub mod filter_null_join_keys;
-pub mod filter_push_down;
pub mod inline_table_scan;
pub mod limit_push_down;
pub mod optimizer;
pub mod projection_push_down;
pub mod propagate_empty_relation;
+pub mod push_down_filter;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 5340f4f80..19eeecb0f 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -25,11 +25,11 @@ use crate::eliminate_filter::EliminateFilter;
use crate::eliminate_limit::EliminateLimit;
use crate::eliminate_outer_join::EliminateOuterJoin;
use crate::filter_null_join_keys::FilterNullJoinKeys;
-use crate::filter_push_down::FilterPushDown;
use crate::inline_table_scan::InlineTableScan;
use crate::limit_push_down::LimitPushDown;
use crate::projection_push_down::ProjectionPushDown;
use crate::propagate_empty_relation::PropagateEmptyRelation;
+use crate::push_down_filter::PushDownFilter;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
use crate::simplify_expressions::SimplifyExpressions;
@@ -184,8 +184,9 @@ impl Optimizer {
rules.push(Arc::new(FilterNullJoinKeys::default()));
}
rules.push(Arc::new(EliminateOuterJoin::new()));
- rules.push(Arc::new(FilterPushDown::new()));
+ // Filters can't be pushed down past Limits, so PushDownFilter must run after LimitPushDown
rules.push(Arc::new(LimitPushDown::new()));
+ rules.push(Arc::new(PushDownFilter::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
// The previous optimizations added expressions and projections,
diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/push_down_filter.rs
similarity index 72%
rename from datafusion/optimizer/src/filter_push_down.rs
rename to datafusion/optimizer/src/push_down_filter.rs
index 2f8a8a8b4..e59590df5 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -12,26 +12,25 @@
// specific language governing permissions and limitations
// under the License.
-//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
+//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan
+use crate::utils::conjunction;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
+use datafusion_expr::utils::exprlist_to_columns;
use datafusion_expr::{
- and, col,
- expr::BinaryExpr,
+ and,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
- logical_plan::{
- Aggregate, CrossJoin, Join, JoinType, Limit, LogicalPlan, Projection, TableScan,
- Union,
- },
+ logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
or,
- utils::{expr_to_columns, exprlist_to_columns, from_plan},
- Expr, Operator, TableProviderFilterPushDown,
+ utils::from_plan,
+ BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown,
};
use std::collections::{HashMap, HashSet};
use std::iter::once;
+use std::sync::Arc;
-/// Filter Push Down optimizer rule pushes filter clauses down the plan
+/// Push Down Filter optimizer rule pushes filter clauses down the plan
/// # Introduction
/// A filter-commutative operation is an operation whose result of filter(op(data)) = op(filter(data)).
/// An example of a filter-commutative operation is a projection; a counter-example is `limit`.
@@ -57,96 +56,7 @@ use std::iter::once;
/// When it passes through a projection, it re-writes the filter's expression taking into account that projection.
/// When multiple filters would have been written, it `AND` their expressions into a single expression.
#[derive(Default)]
-pub struct FilterPushDown {}
-
-/// Filter predicate represented by tuple of expression and its columns
-type Predicate = (Expr, HashSet<Column>);
-
-/// Multiple filter predicates represented by tuple of expressions vector
-/// and corresponding expression columns vector
-type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
-
-#[derive(Debug, Clone, Default)]
-struct State {
- // (predicate, columns on the predicate)
- filters: Vec<Predicate>,
-}
-
-impl State {
- fn append_predicates(&mut self, predicates: Predicates) {
- predicates
- .0
- .into_iter()
- .zip(predicates.1)
- .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone())))
- }
-}
-
-/// returns all predicates in `state` that depend on any of `used_columns`
-/// or the ones that does not reference any columns (e.g. WHERE 1=1)
-fn get_predicates<'a>(
- state: &'a State,
- used_columns: &HashSet<Column>,
-) -> Predicates<'a> {
- state
- .filters
- .iter()
- .filter(|(_, columns)| {
- columns.is_empty()
- || !columns
- .intersection(used_columns)
- .collect::<HashSet<_>>()
- .is_empty()
- })
- .map(|&(ref a, ref b)| (a, b))
- .unzip()
-}
-
-/// Optimizes the plan
-fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
- let new_inputs = plan
- .inputs()
- .iter()
- .map(|input| optimize(input, state.clone()))
- .collect::<Result<Vec<_>>>()?;
-
- let expr = plan.expressions();
- from_plan(plan, &expr, &new_inputs)
-}
-
-// remove all filters from `filters` that are in `predicate_columns`
-fn remove_filters(
- filters: &[Predicate],
- predicate_columns: &[&HashSet<Column>],
-) -> Vec<Predicate> {
- filters
- .iter()
- .filter(|(_, columns)| !predicate_columns.contains(&columns))
- .cloned()
- .collect::<Vec<_>>()
-}
-
-/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
-/// in `state` depend on the columns `used_columns`.
-fn issue_filters(
- mut state: State,
- used_columns: HashSet<Column>,
- plan: &LogicalPlan,
-) -> Result<LogicalPlan> {
- let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
-
- if predicates.is_empty() {
- // all filters can be pushed down => optimize inputs and return new plan
- return push_down(&state, plan);
- }
-
- let plan = utils::add_filter(plan.clone(), &predicates)?;
-
- state.filters = remove_filters(&state.filters, &predicate_columns);
-
- // continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
- push_down(&state, &plan)
-}
+pub struct PushDownFilter {}
// For a given JOIN logical plan, determine whether each side of the join is preserved.
// We say a join side is preserved if the join returns all or a subset of the rows from
@@ -220,15 +130,7 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
// or not the side's rows are preserved when joining. If the side is not preserved, we
// do not push down anything. Otherwise we can push down predicates where all of the
// relevant columns are contained on the relevant join side's schema.
-fn get_pushable_join_predicates<'a>(
- filters: &'a [Predicate],
- schema: &DFSchema,
- preserved: bool,
-) -> Predicates<'a> {
- if !preserved {
- return (vec![], vec![]);
- }
-
+fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> {
let schema_columns = schema
.fields()
.iter()
@@ -240,19 +142,13 @@ fn get_pushable_join_predicates<'a>(
]
})
.collect::<HashSet<_>>();
+ let columns = predicate.to_columns()?;
- filters
- .iter()
- .filter(|(_, columns)| {
- let all_columns_in_schema = schema_columns
- .intersection(columns)
- .collect::<HashSet<_>>()
- .len()
- == columns.len();
- all_columns_in_schema
- })
- .map(|(a, b)| (a, b))
- .unzip()
+ Ok(schema_columns
+ .intersection(&columns)
+ .collect::<HashSet<_>>()
+ .len()
+ == columns.len())
}
// examine OR clause to see if any useful clauses can be extracted and push down.
@@ -292,9 +188,9 @@ fn extract_or_clauses_for_join(
filters: &[&Expr],
schema: &DFSchema,
preserved: bool,
-) -> (Vec<Expr>, Vec<HashSet<Column>>) {
+) -> Vec<Expr> {
if !preserved {
- return (vec![], vec![]);
+ return vec![];
}
let schema_columns = schema
@@ -310,7 +206,6 @@ fn extract_or_clauses_for_join(
.collect::<HashSet<_>>();
let mut exprs = vec![];
- let mut expr_columns = vec![];
for expr in filters.iter() {
if let Expr::BinaryExpr(BinaryExpr {
left,
@@ -323,17 +218,13 @@ fn extract_or_clauses_for_join(
// If nothing can be extracted from any sub clauses, do nothing for this OR clause.
if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
- let predicate = or(left_expr, right_expr);
- let columns = predicate.to_columns().ok().unwrap();
-
- exprs.push(predicate);
- expr_columns.push(columns);
+ exprs.push(or(left_expr, right_expr));
}
}
}
// new formed OR clauses and their column references
- (exprs, expr_columns)
+ exprs
}
// extract qual from OR sub-clause.
@@ -403,94 +294,89 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
predicate
}
-fn optimize_join(
- mut state: State,
+// push down join/cross-join
+fn push_down_all_join(
+ predicates: Vec<Expr>,
plan: &LogicalPlan,
left: &LogicalPlan,
right: &LogicalPlan,
- on_filter: Vec<Predicate>,
+ on_filter: Vec<Expr>,
) -> Result<LogicalPlan> {
+ let on_filter_empty = on_filter.is_empty();
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(plan)?;
- let to_left =
- get_pushable_join_predicates(&state.filters, left.schema(), left_preserved);
- let to_right =
- get_pushable_join_predicates(&state.filters, right.schema(), right_preserved);
- let to_keep: Predicates = state
- .filters
- .iter()
- .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e))
- .map(|(a, b)| (a, b))
- .unzip();
+ let mut left_push = vec![];
+ let mut right_push = vec![];
+
+ let mut keep_predicates = vec![];
+ for predicate in predicates {
+ if left_preserved && can_pushdown_join_predicate(&predicate, left.schema())? {
+ left_push.push(predicate);
+ } else if right_preserved
+ && can_pushdown_join_predicate(&predicate, right.schema())?
+ {
+ right_push.push(predicate);
+ } else {
+ keep_predicates.push(predicate);
+ }
+ }
- // Get pushable predicates from join filter
- let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() {
- ((vec![], vec![]), (vec![], vec![]), vec![])
- } else {
+ let mut keep_condition = vec![];
+ if !on_filter.is_empty() {
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?;
- let on_to_left =
- get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved);
- let on_to_right =
- get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved);
- let on_to_keep = on_filter
- .iter()
- .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e))
- .map(|(a, _)| a.clone())
- .collect::<Vec<_>>();
-
- (on_to_left, on_to_right, on_to_keep)
- };
+ for on in on_filter {
+ if on_left_preserved && can_pushdown_join_predicate(&on, left.schema())? {
+ left_push.push(on)
+ } else if on_right_preserved
+ && can_pushdown_join_predicate(&on, right.schema())?
+ {
+ right_push.push(on)
+ } else {
+ keep_condition.push(on)
+ }
+ }
+ }
// Extract from OR clause, generate new predicates for both side of join if possible.
// We only track the unpushable predicates above.
- let or_to_left =
- extract_or_clauses_for_join(&to_keep.0, left.schema(), left_preserved);
- let or_to_right =
- extract_or_clauses_for_join(&to_keep.0, right.schema(), right_preserved);
+ let or_to_left = extract_or_clauses_for_join(
+ &keep_predicates.iter().collect::<Vec<_>>(),
+ left.schema(),
+ left_preserved,
+ );
+ let or_to_right = extract_or_clauses_for_join(
+ &keep_predicates.iter().collect::<Vec<_>>(),
+ right.schema(),
+ right_preserved,
+ );
let on_or_to_left = extract_or_clauses_for_join(
- &on_to_keep.iter().collect::<Vec<_>>(),
+ &keep_condition.iter().collect::<Vec<_>>(),
left.schema(),
left_preserved,
);
let on_or_to_right = extract_or_clauses_for_join(
- &on_to_keep.iter().collect::<Vec<_>>(),
+ &keep_condition.iter().collect::<Vec<_>>(),
right.schema(),
right_preserved,
);
- // Build new filter states using pushable predicates
- // from current optimizer states and from ON clause.
- // Then recursively call optimization for both join inputs
- let mut left_state = State::default();
- left_state.append_predicates(to_left);
- left_state.append_predicates(on_to_left);
- or_to_left
- .0
- .into_iter()
- .zip(or_to_left.1)
- .for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
- on_or_to_left
- .0
- .into_iter()
- .zip(on_or_to_left.1)
- .for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
- let left = optimize(left, left_state)?;
-
- let mut right_state = State::default();
- right_state.append_predicates(to_right);
- right_state.append_predicates(on_to_right);
- or_to_right
- .0
- .into_iter()
- .zip(or_to_right.1)
- .for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
- on_or_to_right
- .0
- .into_iter()
- .zip(on_or_to_right.1)
- .for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
- let right = optimize(right, right_state)?;
+ left_push.extend(or_to_left);
+ left_push.extend(on_or_to_left);
+ right_push.extend(or_to_right);
+ right_push.extend(on_or_to_right);
+ let left = match conjunction(left_push) {
+ Some(predicate) => {
+ LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?)
+ }
+ None => left.clone(),
+ };
+ let right = match conjunction(right_push) {
+ Some(predicate) => {
+ LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?)
+ }
+ None => right.clone(),
+ };
// Create a new Join with the new `left` and `right`
//
// expressions() output for Join is a vector consisting of
@@ -500,302 +386,359 @@ fn optimize_join(
// vector will contain only join keys (without additional
// element representing filter).
let expr = plan.expressions();
- let expr = if !on_filter.is_empty() && on_to_keep.is_empty() {
+ let expr = if !on_filter_empty && keep_condition.is_empty() {
// New filter expression is None - should remove last element
expr[..expr.len() - 1].to_vec()
- } else if !on_to_keep.is_empty() {
+ } else if !keep_condition.is_empty() {
// Replace last element with new filter expression
expr[..expr.len() - 1]
.iter()
.cloned()
- .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap()))
+ .chain(once(keep_condition.into_iter().reduce(Expr::and).unwrap()))
.collect()
} else {
plan.expressions()
};
let plan = from_plan(plan, &expr, &[left, right])?;
- if to_keep.0.is_empty() {
+ if keep_predicates.is_empty() {
Ok(plan)
} else {
// wrap the join on the filter whose predicates must be kept
- let plan = utils::add_filter(plan, &to_keep.0)?;
- state.filters = remove_filters(&state.filters, &to_keep.1);
-
- Ok(plan)
+ match conjunction(keep_predicates) {
+ Some(predicate) => Ok(LogicalPlan::Filter(Filter::try_new(
+ predicate,
+ Arc::new(plan),
+ )?)),
+ None => Ok(plan),
+ }
}
}
-fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
- match plan {
- LogicalPlan::Explain { .. } => {
- // push the optimization to the plan of this explain
- push_down(&state, plan)
- }
- LogicalPlan::Analyze { .. } => push_down(&state, plan),
- LogicalPlan::Filter(filter) => {
- let predicate = utils::cnf_rewrite(filter.predicate().clone());
-
- utils::split_conjunction_owned(predicate)
- .into_iter()
- .try_for_each::<_, Result<()>>(|predicate| {
- let columns = predicate.to_columns()?;
- state.filters.push((predicate, columns));
- Ok(())
- })?;
-
- optimize(filter.input(), state)
+fn push_down_join(
+ plan: &LogicalPlan,
+ join: &Join,
+ parent_predicate: Option<&Expr>,
+) -> Result<Option<LogicalPlan>> {
+ let mut predicates = match parent_predicate {
+ Some(parent_predicate) => {
+ utils::split_conjunction_owned(utils::cnf_rewrite(parent_predicate.clone()))
}
- LogicalPlan::Projection(Projection {
- input,
- expr,
- schema,
- }) => {
- // A projection is filter-commutable, but re-writes all predicate expressions
- // collect projection.
- let projection = schema
- .fields()
- .iter()
- .enumerate()
- .flat_map(|(i, field)| {
- // strip alias, as they should not be part of filters
- let expr = match &expr[i] {
- Expr::Alias(expr, _) => expr.as_ref().clone(),
- expr => expr.clone(),
+ None => vec![],
+ };
+
+ // Convert JOIN ON predicate to Predicates
+ let on_filters = join
+ .filter
+ .as_ref()
+ .map(|e| utils::split_conjunction_owned(e.clone()))
+ .unwrap_or_else(Vec::new);
+
+ if join.join_type == JoinType::Inner {
+ // For inner joins, duplicate filters for joined columns so filters can be pushed down
+ // to both sides. Take the following query as an example:
+ //
+ // ```sql
+ // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+ // ```
+ //
+ // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+ // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+ //
+ // Join clauses with `Using` constraints also take advantage of this logic to make sure
+ // predicates reference the shared join columns are pushed to both sides.
+ // This logic should also be applied to conditions in the JOIN ON clause
+ let join_side_filters = predicates
+ .iter()
+ .chain(on_filters.iter())
+ .filter_map(|predicate| {
+ let mut join_cols_to_replace = HashMap::new();
+ let columns = match predicate.to_columns() {
+ Ok(columns) => columns,
+ Err(e) => return Some(Err(e)),
+ };
+
+ for col in columns.iter() {
+ for (l, r) in join.on.iter() {
+ if col == l {
+ join_cols_to_replace.insert(col, r);
+ break;
+ } else if col == r {
+ join_cols_to_replace.insert(col, l);
+ break;
+ }
+ }
+ }
+
+ if join_cols_to_replace.is_empty() {
+ return None;
+ }
+
+ let join_side_predicate =
+ match replace_col(predicate.clone(), &join_cols_to_replace) {
+ Ok(p) => p,
+ Err(e) => {
+ return Some(Err(e));
+ }
};
- // Convert both qualified and unqualified fields
- [
- (field.name().clone(), expr.clone()),
- (field.qualified_name(), expr),
- ]
- })
- .collect::<HashMap<_, _>>();
+ Some(Ok(join_side_predicate))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ predicates.extend(join_side_filters);
+ }
+ if on_filters.is_empty() && predicates.is_empty() {
+ return Ok(None);
+ }
+ Ok(Some(push_down_all_join(
+ predicates,
+ plan,
+ &join.left,
+ &join.right,
+ on_filters,
+ )?))
+}
- // re-write all filters based on this projection
- // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
- for (predicate, columns) in state.filters.iter_mut() {
- *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
+impl OptimizerRule for PushDownFilter {
+ fn name(&self) -> &str {
+ "push_down_filter"
+ }
- columns.clear();
- expr_to_columns(predicate, columns)?;
+ fn optimize(
+ &self,
+ plan: &LogicalPlan,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ let filter = match plan {
+ LogicalPlan::Filter(filter) => filter,
+ // we also need to pushdown filter in Join.
+ LogicalPlan::Join(join) => {
+ let optimized_plan = push_down_join(plan, join, None)?;
+ return match optimized_plan {
+ Some(optimized_plan) => {
+ utils::optimize_children(self, &optimized_plan, optimizer_config)
+ }
+ None => utils::optimize_children(self, plan, optimizer_config),
+ };
}
+ _ => return utils::optimize_children(self, plan, optimizer_config),
+ };
- // optimize inner
- let new_input = optimize(input, state)?;
- Ok(from_plan(plan, expr, &[new_input])?)
- }
- LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
- // An aggregate's aggreagate columns are _not_ filter-commutable => collect these:
- // * columns whose aggregation expression depends on
- // * the aggregation columns themselves
-
- // construct set of columns that `aggr_expr` depends on
- let mut used_columns = HashSet::new();
- exprlist_to_columns(aggr_expr, &mut used_columns)?;
-
- let agg_columns = aggr_expr
- .iter()
- .map(|x| Ok(Column::from_name(x.display_name()?)))
- .collect::<Result<HashSet<_>>>()?;
- used_columns.extend(agg_columns);
-
- issue_filters(state, used_columns, plan)
- }
- LogicalPlan::Sort { .. } => {
- // sort is filter-commutable
- push_down(&state, plan)
- }
- LogicalPlan::Union(Union { inputs: _, schema }) => {
- // union changing all qualifiers while building logical plan so we need
- // to rewrite filters to push unqualified columns to inputs
- let projection = schema
- .fields()
- .iter()
- .map(|field| (field.qualified_name(), col(field.name())))
- .collect::<HashMap<_, _>>();
-
- // rewriting predicate expressions using unqualified names as replacements
- if !projection.is_empty() {
- for (predicate, columns) in state.filters.iter_mut() {
- *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
-
- columns.clear();
- expr_to_columns(predicate, columns)?;
+ let child_plan = &**filter.input();
+ let new_plan = match child_plan {
+ LogicalPlan::Filter(child_filter) => {
+ let new_predicate =
+ and(filter.predicate().clone(), child_filter.predicate().clone());
+ let new_plan = LogicalPlan::Filter(Filter::try_new(
+ new_predicate,
+ child_filter.input().clone(),
+ )?);
+ return self.optimize(&new_plan, optimizer_config);
+ }
+ LogicalPlan::Repartition(_)
+ | LogicalPlan::Distinct(_)
+ | LogicalPlan::Sort(_) => {
+ // commutable
+ let new_filter =
+ plan.with_new_inputs(&[
+ (**(child_plan.inputs().get(0).unwrap())).clone()
+ ])?;
+ child_plan.with_new_inputs(&[new_filter])?
+ }
+ LogicalPlan::SubqueryAlias(subquery_alias) => {
+ let mut replace_map = HashMap::new();
+ for (i, field) in
+ subquery_alias.input.schema().fields().iter().enumerate()
+ {
+ replace_map.insert(
+ subquery_alias
+ .schema
+ .fields()
+ .get(i)
+ .unwrap()
+ .qualified_name(),
+ Expr::Column(field.qualified_column()),
+ );
}
+ let new_predicate =
+ replace_cols_by_name(filter.predicate().clone(), &replace_map)?;
+ let new_filter = LogicalPlan::Filter(Filter::try_new(
+ new_predicate,
+ subquery_alias.input.clone(),
+ )?);
+ child_plan.with_new_inputs(&[new_filter])?
}
-
- push_down(&state, plan)
- }
- LogicalPlan::Limit(Limit { input, .. }) => {
- // limit is _not_ filter-commutable => collect all columns from its input
- let used_columns = input
- .schema()
- .fields()
- .iter()
- .map(|f| f.qualified_column())
- .collect::<HashSet<_>>();
- issue_filters(state, used_columns, plan)
- }
- LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
- optimize_join(state, plan, left, right, vec![])
- }
- LogicalPlan::Join(Join {
- left,
- right,
- on,
- filter,
- join_type,
- ..
- }) => {
- // Convert JOIN ON predicate to Predicates
- let on_filters = filter
- .as_ref()
- .map(|e| {
- let predicates = utils::split_conjunction(e);
-
- predicates
- .into_iter()
- .map(|e| Ok((e.clone(), e.to_columns()?)))
- .collect::<Result<Vec<_>>>()
- })
- .unwrap_or_else(|| Ok(vec![]))?;
-
- if *join_type == JoinType::Inner {
- // For inner joins, duplicate filters for joined columns so filters can be pushed down
- // to both sides. Take the following query as an example:
- //
- // ```sql
- // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
- // ```
- //
- // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
- // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
- //
- // Join clauses with `Using` constraints also take advantage of this logic to make sure
- // predicates reference the shared join columns are pushed to both sides.
- // This logic should also been applied to conditions in JOIN ON clause
- let join_side_filters = state
- .filters
+ LogicalPlan::Projection(projection) => {
+ // A projection is filter-commutable, but re-writes all predicate expressions
+ // collect projection.
+ let replace_map = projection
+ .schema
+ .fields()
.iter()
- .chain(on_filters.iter())
- .filter_map(|(predicate, columns)| {
- let mut join_cols_to_replace = HashMap::new();
- for col in columns.iter() {
- for (l, r) in on {
- if col == l {
- join_cols_to_replace.insert(col, r);
- break;
- } else if col == r {
- join_cols_to_replace.insert(col, l);
- break;
- }
- }
- }
+ .enumerate()
+ .map(|(i, field)| {
+ // strip alias, as they should not be part of filters
+ let expr = match &projection.expr[i] {
+ Expr::Alias(expr, _) => expr.as_ref().clone(),
+ expr => expr.clone(),
+ };
+
+ (field.qualified_name(), expr)
+ })
+ .collect::<HashMap<_, _>>();
- if join_cols_to_replace.is_empty() {
- return None;
- }
+ // re-write all filters based on this projection
+ // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
+ let new_filter = LogicalPlan::Filter(Filter::try_new(
+ replace_cols_by_name(filter.predicate().clone(), &replace_map)?,
+ projection.input.clone(),
+ )?);
- let join_side_predicate =
- match replace_col(predicate.clone(), &join_cols_to_replace) {
- Ok(p) => p,
- Err(e) => {
- return Some(Err(e));
- }
- };
-
- let join_side_columns = columns
- .clone()
- .into_iter()
- // replace keys in join_cols_to_replace with values in resulting column
- // set
- .filter(|c| !join_cols_to_replace.contains_key(c))
- .chain(join_cols_to_replace.values().map(|v| (*v).clone()))
- .collect();
-
- Some(Ok((join_side_predicate, join_side_columns)))
- })
- .collect::<Result<Vec<_>>>()?;
- state.filters.extend(join_side_filters);
+ child_plan.with_new_inputs(&[new_filter])?
}
+ LogicalPlan::Union(union) => {
+ let mut inputs = Vec::with_capacity(union.inputs.len());
+ for input in &union.inputs {
+ let mut replace_map = HashMap::new();
+ for (i, field) in input.schema().fields().iter().enumerate() {
+ replace_map.insert(
+ union.schema.fields().get(i).unwrap().qualified_name(),
+ Expr::Column(field.qualified_column()),
+ );
+ }
- optimize_join(state, plan, left, right, on_filters)
- }
- LogicalPlan::TableScan(TableScan {
- source,
- projected_schema,
- filters,
- projection,
- table_name,
- fetch,
- }) => {
- let mut used_columns = HashSet::new();
- let mut new_filters = filters.clone();
-
- for (filter_expr, cols) in &state.filters {
- let (preserve_filter_node, add_to_provider) =
- match source.supports_filter_pushdown(filter_expr)? {
- TableProviderFilterPushDown::Unsupported => (true, false),
- TableProviderFilterPushDown::Inexact => (true, true),
- TableProviderFilterPushDown::Exact => (false, true),
- };
-
- if preserve_filter_node {
- used_columns.extend(cols.clone());
+ let push_predicate =
+ replace_cols_by_name(filter.predicate().clone(), &replace_map)?;
+ inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
+ push_predicate,
+ input.clone(),
+ )?)))
}
-
- if add_to_provider {
- // Don't add expression again if it's already present in
- // pushed down filters.
- if new_filters.contains(filter_expr) {
- continue;
+ LogicalPlan::Union(Union {
+ inputs,
+ schema: plan.schema().clone(),
+ })
+ }
+ LogicalPlan::Aggregate(agg) => {
+ // An aggregate's aggregate columns are _not_ filter-commutable => collect these:
+ // * columns whose aggregation expression depends on
+ // * the aggregation columns themselves
+
+ // construct set of columns that `aggr_expr` depends on
+ let mut used_columns = HashSet::new();
+ exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
+ let agg_columns = agg
+ .aggr_expr
+ .iter()
+ .map(|x| Ok(Column::from_name(x.display_name()?)))
+ .collect::<Result<HashSet<_>>>()?;
+ used_columns.extend(agg_columns);
+
+ let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
+ filter.predicate().clone(),
+ ));
+
+ let mut keep_predicates = vec![];
+ let mut push_predicates = vec![];
+ for expr in predicates {
+ let columns = expr.to_columns()?;
+ if columns.is_empty()
+ || !columns
+ .intersection(&used_columns)
+ .collect::<HashSet<_>>()
+ .is_empty()
+ {
+ keep_predicates.push(expr);
+ } else {
+ push_predicates.push(expr);
}
- new_filters.push(filter_expr.clone());
+ }
+
+ let child = match conjunction(push_predicates) {
+ Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+ predicate,
+ Arc::new((*agg.input).clone()),
+ )?),
+ None => (*agg.input).clone(),
+ };
+ let new_agg = from_plan(
+ filter.input(),
+ &filter.input().expressions(),
+ &vec![child],
+ )?;
+ match conjunction(keep_predicates) {
+ Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+ predicate,
+ Arc::new(new_agg),
+ )?),
+ None => new_agg,
+ }
+ }
+ LogicalPlan::Join(join) => {
+ match push_down_join(filter.input(), join, Some(filter.predicate()))? {
+ Some(optimized_plan) => optimized_plan,
+ None => plan.clone(),
}
}
+ LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
+ let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
+ filter.predicate().clone(),
+ ));
- issue_filters(
- state,
- used_columns,
- &LogicalPlan::TableScan(TableScan {
- source: source.clone(),
- projection: projection.clone(),
- projected_schema: projected_schema.clone(),
- table_name: table_name.clone(),
- filters: new_filters,
- fetch: *fetch,
- }),
- )
- }
- _ => {
- // all other plans are _not_ filter-commutable
- let used_columns = plan
- .schema()
- .fields()
- .iter()
- .map(|f| f.qualified_column())
- .collect::<HashSet<_>>();
- issue_filters(state, used_columns, plan)
- }
- }
-}
+ push_down_all_join(predicates, filter.input(), left, right, vec![])?
+ }
+ LogicalPlan::TableScan(scan) => {
+ let mut new_scan_filters = scan.filters.clone();
+ let mut new_predicate = vec![];
+
+ let filter_predicates = utils::split_conjunction_owned(
+ utils::cnf_rewrite(filter.predicate().clone()),
+ );
+
+ for filter_expr in &filter_predicates {
+ let (preserve_filter_node, add_to_provider) =
+ match scan.source.supports_filter_pushdown(filter_expr)? {
+ TableProviderFilterPushDown::Unsupported => (true, false),
+ TableProviderFilterPushDown::Inexact => (true, true),
+ TableProviderFilterPushDown::Exact => (false, true),
+ };
+ if preserve_filter_node {
+ new_predicate.push(filter_expr.clone());
+ }
+ if add_to_provider {
+ // avoid reduplicated filter expr.
+ if new_scan_filters.contains(filter_expr) {
+ continue;
+ }
+ new_scan_filters.push(filter_expr.clone());
+ }
+ }
-impl OptimizerRule for FilterPushDown {
- fn name(&self) -> &str {
- "filter_push_down"
- }
+ let new_scan = LogicalPlan::TableScan(TableScan {
+ source: scan.source.clone(),
+ projection: scan.projection.clone(),
+ projected_schema: scan.projected_schema.clone(),
+ table_name: scan.table_name.clone(),
+ filters: new_scan_filters,
+ fetch: scan.fetch,
+ });
+
+ match conjunction(new_predicate) {
+ Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+ predicate,
+ Arc::new(new_scan),
+ )?),
+ None => new_scan,
+ }
+ }
+ _ => plan.clone(),
+ };
- fn optimize(
- &self,
- plan: &LogicalPlan,
- _: &mut OptimizerConfig,
- ) -> Result<LogicalPlan> {
- optimize(plan, State::default())
+ utils::optimize_children(self, &new_plan, optimizer_config)
}
}
-impl FilterPushDown {
+impl PushDownFilter {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
@@ -828,25 +771,24 @@ fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<
mod tests {
use super::*;
use crate::test::*;
- use arrow::datatypes::SchemaRef;
+ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::DFSchema;
+ use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
- and, col, in_list, in_subquery, lit, logical_plan::JoinType, sum, Expr,
- LogicalPlanBuilder, Operator, TableSource, TableType,
+ and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
+ Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
};
use std::sync::Arc;
- fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan {
- let rule = FilterPushDown::new();
- rule.optimize(plan, &mut OptimizerConfig::new())
- .expect("failed to optimize plan")
- }
-
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
- let optimized_plan = optimize_plan(plan);
+ fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ let optimized_plan = PushDownFilter::new()
+ .optimize(plan, &mut OptimizerConfig::new())
+ .expect("failed to optimize plan");
let formatted_plan = format!("{:?}", optimized_plan);
- assert_eq!(formatted_plan, expected);
+ assert_eq!(plan.schema(), optimized_plan.schema());
+ assert_eq!(expected, formatted_plan);
+ Ok(())
}
#[test]
@@ -861,8 +803,7 @@ mod tests {
Projection: test.a, test.b\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -879,8 +820,7 @@ mod tests {
\n Limit: skip=0, fetch=10\
\n Projection: test.a, test.b\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -892,8 +832,7 @@ mod tests {
let expected = "\
Filter: Int64(0) = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -910,8 +849,7 @@ mod tests {
\n Projection: test.a, test.b, test.c\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -926,8 +864,20 @@ mod tests {
Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn filter_complex_group_by() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
+ .filter(col("b").gt(lit(10i64)))?
+ .build()?;
+ let expected = "Filter: test.b > Int64(10)\
+ \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
+ \n TableScan: test";
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -942,8 +892,7 @@ mod tests {
Filter: b > Int64(10)\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -966,8 +915,7 @@ mod tests {
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
\n Filter: test.c = Int64(1) OR test.c = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
@@ -983,8 +931,7 @@ mod tests {
Projection: test.a AS b, test.c\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
fn add(left: Expr, right: Expr) -> Expr {
@@ -1029,8 +976,7 @@ mod tests {
Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: test.a * Int32(2) + test.c = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
@@ -1063,8 +1009,7 @@ mod tests {
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
\n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
@@ -1098,9 +1043,7 @@ mod tests {
\n Projection: test.a AS b, test.c\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
@@ -1135,9 +1078,7 @@ mod tests {
\n Projection: test.a AS b, test.c\
\n Filter: test.a > Int64(10)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that when two limits are in place, we jump neither
@@ -1159,26 +1100,24 @@ mod tests {
\n Limit: skip=0, fetch=20\
\n Projection: test.a, test.b\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
fn union_all() -> Result<()> {
let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan.clone())
- .union(LogicalPlanBuilder::from(table_scan).build()?)?
+ let table_scan2 = test_table_scan_with_name("test2")?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .union(LogicalPlanBuilder::from(table_scan2).build()?)?
.filter(col("a").eq(lit(1i64)))?
.build()?;
// filter appears below Union
- let expected = "\
- Union\
- \n Filter: a = Int64(1)\
- \n TableScan: test\
- \n Filter: a = Int64(1)\
- \n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ let expected = "Union\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test\
+ \n Filter: test2.a = Int64(1)\
+ \n TableScan: test2";
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -1193,8 +1132,7 @@ mod tests {
.build()?;
// filter appears below Union
- let expected = "Union\
- \n SubqueryAlias: test2\
+ let expected = "Union\n SubqueryAlias: test2\
\n Projection: test.a AS b\
\n Filter: test.a = Int64(1)\
\n TableScan: test\
@@ -1202,8 +1140,68 @@ mod tests {
\n Projection: test.a AS b\
\n Filter: test.a = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_union_different_schema() -> Result<()> {
+ let left = LogicalPlanBuilder::from(test_table_scan()?)
+ .project(vec![col("a"), col("b"), col("c")])?
+ .build()?;
+
+ let schema = Schema::new(vec![
+ Field::new("d", DataType::UInt32, false),
+ Field::new("e", DataType::UInt32, false),
+ Field::new("f", DataType::UInt32, false),
+ ]);
+ let right = table_scan(Some("test1"), &schema, None)?
+ .project(vec![col("d"), col("e"), col("f")])?
+ .build()?;
+ let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
+ let plan = LogicalPlanBuilder::from(left)
+ .cross_join(&right)?
+ .project(vec![col("test.a"), col("test1.d")])?
+ .filter(filter)?
+ .build()?;
+
+ let expected = "Projection: test.a, test1.d\
+ \n CrossJoin:\
+ \n Projection: test.a, test.b, test.c\
+ \n Filter: test.a = Int32(1)\
+ \n TableScan: test\
+ \n Projection: test1.d, test1.e, test1.f\
+ \n Filter: test1.d > Int32(2)\
+ \n TableScan: test1";
+
+ assert_optimized_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn test_project_same_name_different_qualifier() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let left = LogicalPlanBuilder::from(table_scan)
+ .project(vec![col("a"), col("b"), col("c")])?
+ .build()?;
+ let right_table_scan = test_table_scan_with_name("test1")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a"), col("b"), col("c")])?
+ .build()?;
+ let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
+ let plan = LogicalPlanBuilder::from(left)
+ .cross_join(&right)?
+ .project(vec![col("test.a"), col("test1.a")])?
+ .filter(filter)?
+ .build()?;
+
+ let expected = "Projection: test.a, test1.a\
+ \n CrossJoin:\
+ \n Projection: test.a, test.b, test.c\
+ \n Filter: test.a = Int32(1)\
+ \n TableScan: test\
+ \n Projection: test1.a, test1.b, test1.c\
+ \n Filter: test1.a > Int32(2)\
+ \n TableScan: test1";
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that filters with the same columns are correctly placed
@@ -1238,8 +1236,7 @@ mod tests {
\n Filter: test.a <= Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that filters to be placed on the same depth are ANDed
@@ -1269,8 +1266,7 @@ mod tests {
\n Limit: skip=0, fetch=1\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// verifies that filters on a plan with user nodes are not lost
@@ -1292,8 +1288,7 @@ mod tests {
// not part of the test
assert_eq!(format!("{:?}", plan), expected);
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-on-join predicates on a column common to both sides is pushed to both sides
@@ -1318,8 +1313,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test.a <= Int64(1)\
+ "Filter: test.a <= Int64(1)\
\n Inner Join: test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
@@ -1334,8 +1328,7 @@ mod tests {
\n Projection: test2.a\
\n Filter: test2.a <= Int64(1)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-using-join predicates on a column common to both sides is pushed to both sides
@@ -1359,8 +1352,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test.a <= Int64(1)\
+ "Filter: test.a <= Int64(1)\
\n Inner Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
@@ -1375,8 +1367,7 @@ mod tests {
\n Projection: test2.a\
\n Filter: test2.a <= Int64(1)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-join predicates with columns from both sides are not pushed
@@ -1404,8 +1395,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test.c <= test2.b\
+ "Filter: test.c <= test2.b\
\n Inner Join: test.a = test2.a\
\n Projection: test.a, test.c\
\n TableScan: test\
@@ -1415,8 +1405,7 @@ mod tests {
// expected is equal: no push-down
let expected = &format!("{:?}", plan);
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-join predicates with columns from one side of a join are pushed only to that side
@@ -1444,8 +1433,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test.b <= Int64(1)\
+ "Filter: test.b <= Int64(1)\
\n Inner Join: test.a = test2.a\
\n Projection: test.a, test.b\
\n TableScan: test\
@@ -1460,8 +1448,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a, test2.c\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-join predicates on the right side of a left join are not duplicated
@@ -1486,8 +1473,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test2.a <= Int64(1)\
+ "Filter: test2.a <= Int64(1)\
\n Left Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
@@ -1501,12 +1487,10 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-join predicates on the left side of a right join are not duplicated
- /// TODO: In this case we can sometimes convert the join to an INNER join
#[test]
fn filter_using_right_join() -> Result<()> {
let table_scan = test_table_scan()?;
@@ -1527,8 +1511,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test.a <= Int64(1)\
+ "Filter: test.a <= Int64(1)\
\n Right Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
@@ -1542,8 +1525,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-left-join predicate on a column common to both sides is only pushed to the left side
@@ -1568,8 +1550,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test.a <= Int64(1)\
+ "Filter: test.a <= Int64(1)\
\n Left Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
@@ -1583,8 +1564,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// post-right-join predicate on a column common to both sides is only pushed to the right side
@@ -1609,8 +1589,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: test2.a <= Int64(1)\
+ "Filter: test2.a <= Int64(1)\
\n Right Join: Using test.a = test2.a\
\n TableScan: test\
\n Projection: test2.a\
@@ -1624,8 +1603,7 @@ mod tests {
\n Projection: test2.a\
\n Filter: test2.a <= Int64(1)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// single table predicate parts of ON condition should be pushed to both inputs
@@ -1655,8 +1633,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+ "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
\n Projection: test.a, test.b, test.c\
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
@@ -1671,8 +1648,7 @@ mod tests {
\n Projection: test2.a, test2.b, test2.c\
\n Filter: test2.c > UInt32(4)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// join filter should be completely removed after pushdown
@@ -1701,8 +1677,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
+ "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\
\n Projection: test.a, test.b, test.c\
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
@@ -1717,8 +1692,7 @@ mod tests {
\n Projection: test2.a, test2.b, test2.c\
\n Filter: test2.c > UInt32(4)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// predicate on join key in filter expression should be pushed down to both inputs
@@ -1745,8 +1719,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
+ "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\
\n Projection: test.a\
\n TableScan: test\
\n Projection: test2.b\
@@ -1761,8 +1734,7 @@ mod tests {
\n Projection: test2.b\
\n Filter: test2.b > UInt32(1)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// single table predicate parts of ON condition should be pushed to right input
@@ -1792,8 +1764,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+ "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
\n Projection: test.a, test.b, test.c\
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
@@ -1807,8 +1778,7 @@ mod tests {
\n Projection: test2.a, test2.b, test2.c\
\n Filter: test2.c > UInt32(4)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// single table predicate parts of ON condition should be pushed to left input
@@ -1838,8 +1808,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+ "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
\n Projection: test.a, test.b, test.c\
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
@@ -1853,8 +1822,7 @@ mod tests {
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// single table predicate parts of ON condition should not be pushed
@@ -1884,8 +1852,7 @@ mod tests {
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
- "\
- Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
+ "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\
\n Projection: test.a, test.b, test.c\
\n TableScan: test\
\n Projection: test2.a, test2.b, test2.c\
@@ -1893,8 +1860,7 @@ mod tests {
);
let expected = &format!("{:?}", plan);
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
struct PushDownProvider {
@@ -1961,8 +1927,7 @@ mod tests {
let expected = "\
TableScan: test, full_filters=[a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -1973,8 +1938,7 @@ mod tests {
let expected = "\
Filter: a = Int64(1)\
\n TableScan: test, partial_filters=[a = Int64(1)]";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -1982,7 +1946,9 @@ mod tests {
let plan =
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
- let optimised_plan = optimize_plan(&plan);
+ let optimised_plan = PushDownFilter::new()
+ .optimize(&plan, &mut OptimizerConfig::new())
+ .expect("failed to optimize plan");
let expected = "\
Filter: a = Int64(1)\
@@ -1990,8 +1956,7 @@ mod tests {
// Optimizing the same plan multiple times should produce the same plan
// each time.
- assert_optimized_plan_eq(&optimised_plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&optimised_plan, expected)
}
#[test]
@@ -2002,8 +1967,7 @@ mod tests {
let expected = "\
Filter: a = Int64(1)\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2028,13 +1992,11 @@ mod tests {
.project(vec![col("a"), col("b")])?
.build()?;
- let expected ="Projection: a, b\
+ let expected = "Projection: a, b\
\n Filter: a = Int64(10) AND b > Int64(11)\
\n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2051,11 +2013,9 @@ mod tests {
// filter on col b
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: b > Int64(10) AND test.c > Int64(10)\
+ "Filter: b > Int64(10) AND test.c > Int64(10)\
\n Projection: test.a AS b, test.c\
- \n TableScan: test\
- "
+ \n TableScan: test"
);
// rewrite filter col b to test.a
@@ -2065,9 +2025,7 @@ mod tests {
\n TableScan: test\
";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2085,8 +2043,7 @@ mod tests {
// filter on col b
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: b > Int64(10) AND test.c > Int64(10)\
+ "Filter: b > Int64(10) AND test.c > Int64(10)\
\n Projection: b, test.c\
\n Projection: test.a AS b, test.c\
\n TableScan: test\
@@ -2101,9 +2058,7 @@ mod tests {
\n TableScan: test\
";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2117,8 +2072,7 @@ mod tests {
// filter on col b and d
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: b > Int64(10) AND d > Int64(10)\
+ "Filter: b > Int64(10) AND d > Int64(10)\
\n Projection: test.a AS b, test.c AS d\
\n TableScan: test\
"
@@ -2131,9 +2085,7 @@ mod tests {
\n TableScan: test\
";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
/// predicate on join key in filter expression should be pushed down to both inputs
@@ -2159,8 +2111,7 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
- "\
- Inner Join: c = d Filter: c > UInt32(1)\
+ "Inner Join: c = d Filter: c > UInt32(1)\
\n Projection: test.a AS c\
\n TableScan: test\
\n Projection: test2.b AS d\
@@ -2176,8 +2127,7 @@ mod tests {
\n Projection: test2.b AS d\
\n Filter: test2.b > UInt32(1)\
\n TableScan: test2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2195,8 +2145,7 @@ mod tests {
// filter on col b
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
+ "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
\n Projection: test.a AS b, test.c\
\n TableScan: test\
"
@@ -2209,9 +2158,7 @@ mod tests {
\n TableScan: test\
";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2230,8 +2177,7 @@ mod tests {
// filter on col b
assert_eq!(
format!("{:?}", plan),
- "\
- Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
+ "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\
\n Projection: b, test.c\
\n Projection: test.a AS b, test.c\
\n TableScan: test\
@@ -2246,9 +2192,7 @@ mod tests {
\n TableScan: test\
";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -2285,9 +2229,7 @@ mod tests {
\n Projection: sq.c\
\n TableScan: sq\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected_after);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected_after)
}
#[test]
@@ -2312,15 +2254,13 @@ mod tests {
// Ensure that the predicate without any columns (0 = 1) is
// still there.
let expected_after = "Projection: b.a\
- \n Filter: b.a = Int64(1)\
- \n SubqueryAlias: b\
- \n Projection: b.a\
- \n SubqueryAlias: b\
- \n Projection: Int64(0) AS a\
+ \n SubqueryAlias: b\
+ \n Projection: b.a\
+ \n SubqueryAlias: b\
+ \n Projection: Int64(0) AS a\
+ \n Filter: Int64(0) = Int64(1)\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected_after);
-
- Ok(())
+ assert_optimized_plan_eq(&plan, expected_after)
}
#[test]
@@ -2351,7 +2291,13 @@ mod tests {
\n TableScan: test\
\n Projection: test1.a AS d, test1.a AS e\
\n TableScan: test1";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_optimized_plan_eq(&plan, expected)?;
+
+ // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
+ // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
+ let optimized_plan = PushDownFilter::new()
+ .optimize(&plan, &mut OptimizerConfig::new())
+ .expect("failed to optimize plan");
+ assert_optimized_plan_eq(&optimized_plan, expected)
}
}
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index c4911439c..457ea833e 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -274,12 +274,12 @@ fn join_keys_in_subquery_alias() {
let plan = test_sql(sql).unwrap();
let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
\n Inner Join: a.col_int32 = b.key\
- \n Filter: a.col_int32 IS NOT NULL\
- \n SubqueryAlias: a\
+ \n SubqueryAlias: a\
+ \n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
- \n Filter: b.key IS NOT NULL\
- \n SubqueryAlias: b\
- \n Projection: test.col_int32 AS key\
+ \n SubqueryAlias: b\
+ \n Projection: test.col_int32 AS key\
+ \n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));
}
@@ -288,20 +288,19 @@ fn join_keys_in_subquery_alias() {
fn join_keys_in_subquery_alias_1() {
let sql = "SELECT * FROM test AS A, ( SELECT test.col_int32 AS key FROM test JOIN test AS C on test.col_int32 = C.col_int32 ) AS B where A.col_int32 = B.key;";
let plan = test_sql(sql).unwrap();
- let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
+ let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
\n Inner Join: a.col_int32 = b.key\
- \n Filter: a.col_int32 IS NOT NULL\
- \n SubqueryAlias: a\
+ \n SubqueryAlias: a\
+ \n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
- \n Filter: b.key IS NOT NULL\
- \n SubqueryAlias: b\
- \n Projection: test.col_int32 AS key\
- \n Inner Join: test.col_int32 = c.col_int32\
+ \n SubqueryAlias: b\
+ \n Projection: test.col_int32 AS key\
+ \n Inner Join: test.col_int32 = c.col_int32\
+ \n Filter: test.col_int32 IS NOT NULL\
+ \n TableScan: test projection=[col_int32]\
+ \n SubqueryAlias: c\
\n Filter: test.col_int32 IS NOT NULL\
- \n TableScan: test projection=[col_int32]\
- \n Filter: c.col_int32 IS NOT NULL\
- \n SubqueryAlias: c\
- \n TableScan: test projection=[col_int32]";
+ \n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{:?}", plan));
}