You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/20 18:26:34 UTC
[arrow-datafusion] branch master updated: optimizer: remove recursion in optimizer rules (#4650)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 27921135e optimizer: remove recursion in optimizer rules (#4650)
27921135e is described below
commit 27921135e4ff4b644251db6ab42f1a25bd6523cb
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Dec 21 02:26:27 2022 +0800
optimizer: remove recursion in optimizer rules (#4650)
---
benchmarks/expected-plans/q20.txt | 8 +-
datafusion/core/tests/sql/subqueries.rs | 20 +--
datafusion/expr/src/logical_plan/plan.rs | 8 +-
.../optimizer/src/decorrelate_where_exists.rs | 80 +++++-----
datafusion/optimizer/src/decorrelate_where_in.rs | 118 +++++++++------
datafusion/optimizer/src/eliminate_filter.rs | 167 +++++++++------------
datafusion/optimizer/src/eliminate_limit.rs | 39 +++--
datafusion/optimizer/src/eliminate_outer_join.rs | 38 +++--
datafusion/optimizer/src/filter_null_join_keys.rs | 61 +++-----
datafusion/optimizer/src/inline_table_scan.rs | 51 +++----
.../optimizer/src/propagate_empty_relation.rs | 158 +++++++++----------
.../optimizer/src/scalar_subquery_to_join.rs | 125 ++++++++++-----
datafusion/optimizer/src/test/mod.rs | 62 ++++++--
13 files changed, 479 insertions(+), 456 deletions(-)
diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt
index 1266622ea..b2676f61f 100644
--- a/benchmarks/expected-plans/q20.txt
+++ b/benchmarks/expected-plans/q20.txt
@@ -1,17 +1,17 @@
Sort: supplier.s_name ASC NULLS LAST
Projection: supplier.s_name, supplier.s_address
- LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey
+ LeftSemi Join: supplier.s_suppkey = __sq_1.ps_suppkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
- SubqueryAlias: __sq_2
+ SubqueryAlias: __sq_1
Projection: partsupp.ps_suppkey AS ps_suppkey
Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
- LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
+ LeftSemi Join: partsupp.ps_partkey = __sq_2.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
- SubqueryAlias: __sq_1
+ SubqueryAlias: __sq_2
Projection: part.p_partkey AS p_partkey
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name]
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index e6c98edf5..d221ddfe2 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -52,16 +52,16 @@ where c_acctbal < (
let actual = format!("{}", plan.display_indent());
let expected = "Sort: customer.c_custkey ASC NULLS LAST\
\n Projection: customer.c_custkey\
- \n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __sq_2.__value\
- \n Inner Join: customer.c_custkey = __sq_2.o_custkey\
+ \n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __sq_1.__value\
+ \n Inner Join: customer.c_custkey = __sq_1.o_custkey\
\n TableScan: customer projection=[c_custkey, c_acctbal]\
- \n SubqueryAlias: __sq_2\
+ \n SubqueryAlias: __sq_1\
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\
- \n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __sq_1.__value\
- \n Inner Join: orders.o_orderkey = __sq_1.l_orderkey\
+ \n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __sq_2.__value\
+ \n Inner Join: orders.o_orderkey = __sq_2.l_orderkey\
\n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\
- \n SubqueryAlias: __sq_1\
+ \n SubqueryAlias: __sq_2\
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice]";
@@ -324,18 +324,18 @@ order by s_name;
let actual = format!("{}", plan.display_indent());
let expected = "Sort: supplier.s_name ASC NULLS LAST\
\n Projection: supplier.s_name, supplier.s_address\
- \n LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey\
+ \n LeftSemi Join: supplier.s_suppkey = __sq_1.ps_suppkey\
\n Inner Join: supplier.s_nationkey = nation.n_nationkey\
\n TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]\
\n Filter: nation.n_name = Utf8(\"CANADA\")\
\n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"CANADA\")]\
- \n SubqueryAlias: __sq_2\
+ \n SubqueryAlias: __sq_1\
\n Projection: partsupp.ps_suppkey AS ps_suppkey\
\n Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value\
\n Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey\
- \n LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey\
+ \n LeftSemi Join: partsupp.ps_partkey = __sq_2.p_partkey\
\n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]\
- \n SubqueryAlias: __sq_1\
+ \n SubqueryAlias: __sq_2\
\n Projection: part.p_partkey AS p_partkey\
\n Filter: part.p_name LIKE Utf8(\"forest%\")\
\n TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8(\"forest%\")]\
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 14dfe7143..9d7fdf8f0 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1329,12 +1329,16 @@ impl SubqueryAlias {
/// If the value of `<predicate>` is true, the input row is passed to
/// the output. If the value of `<predicate>` is false, the row is
/// discarded.
+///
+/// `Filter` should not be constructed directly; use `try_new()` instead.
+/// The fields are `pub` only to support pattern matching.
#[derive(Clone)]
+#[non_exhaustive]
pub struct Filter {
/// The predicate expression, which must have Boolean type.
- predicate: Expr,
+ pub predicate: Expr,
/// The incoming logical plan
- input: Arc<LogicalPlan>,
+ pub input: Arc<LogicalPlan>,
}
impl Filter {
diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs
index f1addf651..50bbf6bb5 100644
--- a/datafusion/optimizer/src/decorrelate_where_exists.rs
+++ b/datafusion/optimizer/src/decorrelate_where_exists.rs
@@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.
+use crate::optimizer::ApplyOrder;
use crate::utils::{
conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction,
verify_not_disjunction,
};
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{context, Result};
use datafusion_expr::{
logical_plan::{Filter, JoinType, Subquery},
@@ -81,27 +82,15 @@ impl OptimizerRule for DecorrelateWhereExists {
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => {
- let predicate = filter.predicate();
- let filter_input = filter.input().as_ref();
-
- // Apply optimizer rule to current input
- let optimized_input = self
- .try_optimize(filter_input, config)?
- .unwrap_or_else(|| filter_input.clone());
-
let (subqueries, other_exprs) =
- self.extract_subquery_exprs(predicate, config)?;
- let optimized_plan = LogicalPlan::Filter(Filter::try_new(
- predicate.clone(),
- Arc::new(optimized_input),
- )?);
+ self.extract_subquery_exprs(filter.predicate(), config)?;
if subqueries.is_empty() {
// regular filter, no subquery exists clause here
- return Ok(Some(optimized_plan));
+ return Ok(None);
}
// iterate through all exists clauses in predicate, turning each into a join
- let mut cur_input = filter_input.clone();
+ let mut cur_input = filter.input().as_ref().clone();
for subquery in subqueries {
if let Some(x) = optimize_exists(&subquery, &cur_input, &other_exprs)?
{
@@ -112,16 +101,17 @@ impl OptimizerRule for DecorrelateWhereExists {
}
Ok(Some(cur_input))
}
- _ => {
- // Apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
- }
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
"decorrelate_where_exists"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::TopDown)
+ }
}
/// Takes a query like:
@@ -226,6 +216,15 @@ mod tests {
};
use std::ops::Add;
+ fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereExists::new()),
+ plan,
+ expected,
+ );
+ Ok(())
+ }
+
/// Test for multiple exists subqueries in the same filter expression
#[test]
fn multiple_subqueries() -> Result<()> {
@@ -248,8 +247,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test recursive correlated subqueries
@@ -284,8 +282,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists subquery filter with additional subquery filters
@@ -313,8 +310,7 @@ mod tests {
Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists subquery with no columns in schema
@@ -332,8 +328,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
- Ok(())
+ assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}
/// Test for exists subquery with both columns in schema
@@ -351,8 +346,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
- Ok(())
+ assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}
/// Test for correlated exists subquery not equal
@@ -370,8 +364,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
- Ok(())
+ assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}
/// Test for correlated exists subquery less than
@@ -391,7 +384,7 @@ mod tests {
let expected = r#"can't optimize < column comparison"#;
- assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}
@@ -416,7 +409,7 @@ mod tests {
let expected = r#"Optimizing disjunctions not supported!"#;
- assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}
@@ -434,8 +427,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
- Ok(())
+ assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}
/// Test for correlated exists expressions
@@ -459,8 +451,7 @@ mod tests {
TableScan: customer [c_custkey:Int64, c_name:Utf8]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists subquery filter with additional filters
@@ -483,8 +474,7 @@ mod tests {
TableScan: customer [c_custkey:Int64, c_name:Utf8]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated exists subquery filter with disjunctions
@@ -511,8 +501,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for correlated EXISTS subquery filter
@@ -535,8 +524,7 @@ mod tests {
TableScan: test [a:UInt32, b:UInt32, c:UInt32]
TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
- assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
- Ok(())
+ assert_plan_eq(&plan, expected)
}
/// Test for single exists subquery filter
@@ -550,7 +538,7 @@ mod tests {
let expected = "cannot optimize non-correlated subquery";
- assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}
@@ -565,7 +553,7 @@ mod tests {
let expected = "cannot optimize non-correlated subquery";
- assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}
}
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs
index d2555ea5c..91dd9c550 100644
--- a/datafusion/optimizer/src/decorrelate_where_in.rs
+++ b/datafusion/optimizer/src/decorrelate_where_in.rs
@@ -15,13 +15,14 @@
// specific language governing permissions and limitations
// under the License.
+use crate::optimizer::ApplyOrder;
use crate::utils::{
alias_cols, conjunction, exprs_to_join_cols, find_join_exprs, merge_cols,
only_or_err, split_conjunction, swap_table, verify_not_disjunction,
};
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{context, Result};
-use datafusion_expr::logical_plan::{Filter, JoinType, Projection, Subquery};
+use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use log::debug;
use std::sync::Arc;
@@ -85,43 +86,32 @@ impl OptimizerRule for DecorrelateWhereIn {
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => {
- let predicate = filter.predicate();
- let filter_input = filter.input().as_ref();
-
- // Apply optimizer rule to current input
- let optimized_input = self
- .try_optimize(filter_input, config)?
- .unwrap_or_else(|| filter_input.clone());
-
let (subqueries, other_exprs) =
- self.extract_subquery_exprs(predicate, config)?;
- let optimized_plan = LogicalPlan::Filter(Filter::try_new(
- predicate.clone(),
- Arc::new(optimized_input),
- )?);
+ self.extract_subquery_exprs(filter.predicate(), config)?;
if subqueries.is_empty() {
// regular filter, no subquery exists clause here
- return Ok(Some(optimized_plan));
+ return Ok(None);
}
// iterate through all exists clauses in predicate, turning each into a join
- let mut cur_input = filter_input.clone();
+ let mut cur_input = filter.input().as_ref().clone();
for subquery in subqueries {
cur_input =
optimize_where_in(&subquery, &cur_input, &other_exprs, config)?;
}
Ok(Some(cur_input))
}
- _ => {
- // Apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
- }
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
"decorrelate_where_in"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::TopDown)
+ }
}
fn optimize_where_in(
@@ -268,7 +258,11 @@ mod tests {
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n SubqueryAlias: __sq_2 [o_custkey:Int64]\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -299,17 +293,21 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __sq_2 [o_custkey:Int64]\
+ \n SubqueryAlias: __sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
- \n LeftSemi Join: orders.o_orderkey = __sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n LeftSemi Join: orders.o_orderkey = __sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __sq_1 [l_orderkey:Int64]\
+ \n SubqueryAlias: __sq_2 [l_orderkey:Int64]\
\n Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\
\n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -340,7 +338,11 @@ mod tests {
\n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -368,7 +370,11 @@ mod tests {
\n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -395,7 +401,11 @@ mod tests {
\n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -421,7 +431,11 @@ mod tests {
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -442,7 +456,7 @@ mod tests {
// can't optimize on arbitrary expressions (yet)
assert_optimizer_err(
- &DecorrelateWhereIn::new(),
+ Arc::new(DecorrelateWhereIn::new()),
&plan,
"column correlation not found",
);
@@ -469,7 +483,7 @@ mod tests {
.build()?;
assert_optimizer_err(
- &DecorrelateWhereIn::new(),
+ Arc::new(DecorrelateWhereIn::new()),
&plan,
"Optimizing disjunctions not supported!",
);
@@ -492,7 +506,7 @@ mod tests {
// Maybe okay if the table only has a single column?
assert_optimizer_err(
- &DecorrelateWhereIn::new(),
+ Arc::new(DecorrelateWhereIn::new()),
&plan,
"a projection is required",
);
@@ -516,7 +530,7 @@ mod tests {
// TODO: support join on expression
assert_optimizer_err(
- &DecorrelateWhereIn::new(),
+ Arc::new(DecorrelateWhereIn::new()),
&plan,
"column comparison required",
);
@@ -540,7 +554,7 @@ mod tests {
// TODO: support join on expressions?
assert_optimizer_err(
- &DecorrelateWhereIn::new(),
+ Arc::new(DecorrelateWhereIn::new()),
&plan,
"single column projection required",
);
@@ -566,7 +580,7 @@ mod tests {
.build()?;
assert_optimizer_err(
- &DecorrelateWhereIn::new(),
+ Arc::new(DecorrelateWhereIn::new()),
&plan,
"single expression projection required",
);
@@ -599,7 +613,11 @@ mod tests {
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -630,7 +648,11 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -656,7 +678,11 @@ mod tests {
\n Projection: sq.c AS c, sq.a AS a [c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -676,7 +702,11 @@ mod tests {
\n Projection: sq.c AS c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -696,7 +726,11 @@ mod tests {
\n Projection: sq.c AS c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
}
diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs
index 7636a6a9f..e4fc80341 100644
--- a/datafusion/optimizer/src/eliminate_filter.rs
+++ b/datafusion/optimizer/src/eliminate_filter.rs
@@ -18,13 +18,14 @@
//! Optimizer rule to replace `where false` on a plan with an empty relation.
//! This saves time in planning and executing the query.
//! Note that this rule should be applied after simplify expressions optimizer rule.
+use crate::optimizer::ApplyOrder;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
logical_plan::{EmptyRelation, LogicalPlan},
- Expr,
+ Expr, Filter,
};
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
/// Optimization rule that eliminates the scalar value (true/false) filter with a [LogicalPlan::EmptyRelation]
#[derive(Default)]
@@ -41,139 +42,119 @@ impl OptimizerRule for EliminateFilter {
fn try_optimize(
&self,
plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
- let predicate_and_input = match plan {
- LogicalPlan::Filter(filter) => match filter.predicate() {
- Expr::Literal(ScalarValue::Boolean(Some(v))) => {
- Some((*v, filter.input()))
+ match plan {
+ LogicalPlan::Filter(Filter {
+ predicate: Expr::Literal(ScalarValue::Boolean(Some(v))),
+ input,
+ ..
+ }) => {
+ match *v {
+ // input also can be filter, apply again
+ true => Ok(Some(
+ self.try_optimize(input, _config)?
+ .unwrap_or_else(|| input.as_ref().clone()),
+ )),
+ false => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: input.schema().clone(),
+ }))),
}
- _ => None,
- },
- _ => None,
- };
-
- match predicate_and_input {
- Some((true, input)) => self.try_optimize(input, config),
- Some((false, input)) => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: input.schema().clone(),
- }))),
- None => {
- // Apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
}
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
"eliminate_filter"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::TopDown)
+ }
}
#[cfg(test)]
mod tests {
- use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, sum};
+ use crate::eliminate_filter::EliminateFilter;
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::{
+ col, lit, logical_plan::builder::LogicalPlanBuilder, sum, Expr, LogicalPlan,
+ };
+ use std::sync::Arc;
- use crate::optimizer::OptimizerContext;
use crate::test::*;
- use super::*;
-
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
- let rule = EliminateFilter::new();
- let optimized_plan = rule
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
- let formatted_plan = format!("{:?}", optimized_plan);
- assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
+ fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected)
}
#[test]
- fn filter_false() {
+ fn filter_false() -> Result<()> {
let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false)));
let table_scan = test_table_scan().unwrap();
let plan = LogicalPlanBuilder::from(table_scan)
- .aggregate(vec![col("a")], vec![sum(col("b"))])
- .unwrap()
- .filter(filter_expr)
- .unwrap()
- .build()
- .unwrap();
+ .aggregate(vec![col("a")], vec![sum(col("b"))])?
+ .filter(filter_expr)?
+ .build()?;
// No aggregate / scan / limit
let expected = "EmptyRelation";
- assert_optimized_plan_eq(&plan, expected);
+ assert_eq(&plan, expected)
}
#[test]
- fn filter_false_nested() {
+ fn filter_false_nested() -> Result<()> {
let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false)));
- let table_scan = test_table_scan().unwrap();
+ let table_scan = test_table_scan()?;
let plan1 = LogicalPlanBuilder::from(table_scan.clone())
- .aggregate(vec![col("a")], vec![sum(col("b"))])
- .unwrap()
- .build()
- .unwrap();
+ .aggregate(vec![col("a")], vec![sum(col("b"))])?
+ .build()?;
let plan = LogicalPlanBuilder::from(table_scan)
- .aggregate(vec![col("a")], vec![sum(col("b"))])
- .unwrap()
- .filter(filter_expr)
- .unwrap()
- .union(plan1)
- .unwrap()
- .build()
- .unwrap();
+ .aggregate(vec![col("a")], vec![sum(col("b"))])?
+ .filter(filter_expr)?
+ .union(plan1)?
+ .build()?;
// Left side is removed
let expected = "Union\
\n EmptyRelation\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
+ assert_eq(&plan, expected)
}
#[test]
- fn filter_true() {
+ fn filter_true() -> Result<()> {
let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true)));
- let table_scan = test_table_scan().unwrap();
+ let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan)
- .aggregate(vec![col("a")], vec![sum(col("b"))])
- .unwrap()
- .filter(filter_expr)
- .unwrap()
- .build()
- .unwrap();
+ .aggregate(vec![col("a")], vec![sum(col("b"))])?
+ .filter(filter_expr)?
+ .build()?;
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
+ assert_eq(&plan, expected)
}
#[test]
- fn filter_true_nested() {
+ fn filter_true_nested() -> Result<()> {
let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true)));
- let table_scan = test_table_scan().unwrap();
+ let table_scan = test_table_scan()?;
let plan1 = LogicalPlanBuilder::from(table_scan.clone())
- .aggregate(vec![col("a")], vec![sum(col("b"))])
- .unwrap()
- .build()
- .unwrap();
+ .aggregate(vec![col("a")], vec![sum(col("b"))])?
+ .build()?;
let plan = LogicalPlanBuilder::from(table_scan)
- .aggregate(vec![col("a")], vec![sum(col("b"))])
- .unwrap()
- .filter(filter_expr)
- .unwrap()
- .union(plan1)
- .unwrap()
- .build()
- .unwrap();
+ .aggregate(vec![col("a")], vec![sum(col("b"))])?
+ .filter(filter_expr)?
+ .union(plan1)?
+ .build()?;
// Filter is removed
let expected = "Union\
@@ -181,35 +162,29 @@ mod tests {
\n TableScan: test\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\
\n TableScan: test";
- assert_optimized_plan_eq(&plan, expected);
+ assert_eq(&plan, expected)
}
#[test]
- fn filter_from_subquery() {
+ fn filter_from_subquery() -> Result<()> {
// SELECT a FROM (SELECT a FROM test WHERE FALSE) WHERE TRUE
let false_filter = lit(false);
- let table_scan = test_table_scan().unwrap();
+ let table_scan = test_table_scan()?;
let plan1 = LogicalPlanBuilder::from(table_scan)
- .project(vec![col("a")])
- .unwrap()
- .filter(false_filter)
- .unwrap()
- .build()
- .unwrap();
+ .project(vec![col("a")])?
+ .filter(false_filter)?
+ .build()?;
let true_filter = lit(true);
let plan = LogicalPlanBuilder::from(plan1)
- .project(vec![col("a")])
- .unwrap()
- .filter(true_filter)
- .unwrap()
- .build()
- .unwrap();
+ .project(vec![col("a")])?
+ .filter(true_filter)?
+ .build()?;
// Filter is removed
let expected = "Projection: test.a\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected);
+ assert_eq(&plan, expected)
}
}
diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs
index 9e3cbf6fa..caea145dd 100644
--- a/datafusion/optimizer/src/eliminate_limit.rs
+++ b/datafusion/optimizer/src/eliminate_limit.rs
@@ -43,28 +43,25 @@ impl OptimizerRule for EliminateLimit {
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
- let limit = match plan {
- LogicalPlan::Limit(limit) => limit,
- _ => return Ok(None),
- };
-
- match limit.fetch {
- Some(fetch) => {
- if fetch == 0 {
- return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
- produce_one_row: false,
- schema: limit.input.schema().clone(),
- })));
+ if let LogicalPlan::Limit(limit) = plan {
+ match limit.fetch {
+ Some(fetch) => {
+ if fetch == 0 {
+ return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: limit.input.schema().clone(),
+ })));
+ }
}
- }
- None => {
- if limit.skip == 0 {
- let input = limit.input.as_ref();
- // input also can be Limit, so we should apply again.
- return Ok(Some(
- self.try_optimize(input, _config)?
- .unwrap_or_else(|| input.clone()),
- ));
+ None => {
+ if limit.skip == 0 {
+ let input = limit.input.as_ref();
+ // input also can be Limit, so we should apply again.
+ return Ok(Some(
+ self.try_optimize(input, _config)?
+ .unwrap_or_else(|| input.clone()),
+ ));
+ }
}
}
}
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index cc535117d..8c02950d8 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -16,7 +16,7 @@
// under the License.
//! Optimizer rule to eliminate left/right/full join to inner join if possible.
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, Result};
use datafusion_expr::{
logical_plan::{Join, JoinType, LogicalPlan},
@@ -24,6 +24,7 @@ use datafusion_expr::{
};
use datafusion_expr::{Expr, Operator};
+use crate::optimizer::ApplyOrder;
use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
use std::sync::Arc;
@@ -64,7 +65,7 @@ impl OptimizerRule for EliminateOuterJoin {
fn try_optimize(
&self,
plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => match filter.input().as_ref() {
@@ -109,17 +110,21 @@ impl OptimizerRule for EliminateOuterJoin {
null_equals_null: join.null_equals_null,
});
let new_plan = from_plan(plan, &plan.expressions(), &[new_join])?;
- Ok(Some(utils::optimize_children(self, &new_plan, config)?))
+ Ok(Some(new_plan))
}
- _ => Ok(Some(utils::optimize_children(self, plan, config)?)),
+ _ => Ok(None),
},
- _ => Ok(Some(utils::optimize_children(self, plan, config)?)),
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
"eliminate_outer_join"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::TopDown)
+ }
}
pub fn eliminate_outer(
@@ -295,7 +300,6 @@ fn extract_non_nullable_columns(
#[cfg(test)]
mod tests {
use super::*;
- use crate::optimizer::OptimizerContext;
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_expr::{
@@ -305,16 +309,8 @@ mod tests {
Operator::{And, Or},
};
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
- let rule = EliminateOuterJoin::new();
- let optimized_plan = rule
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
- let formatted_plan = format!("{:?}", optimized_plan);
- assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
- Ok(())
+ fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected)
}
#[test]
@@ -337,7 +333,7 @@ mod tests {
\n Left Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_eq(&plan, expected)
}
#[test]
@@ -360,7 +356,7 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_eq(&plan, expected)
}
#[test]
@@ -387,7 +383,7 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_eq(&plan, expected)
}
#[test]
@@ -414,7 +410,7 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_eq(&plan, expected)
}
#[test]
@@ -441,6 +437,6 @@ mod tests {
\n Inner Join: t1.a = t2.a\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected)
+ assert_eq(&plan, expected)
}
}
diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs
index eea98ad1f..8a6c995de 100644
--- a/datafusion/optimizer/src/filter_null_join_keys.rs
+++ b/datafusion/optimizer/src/filter_null_join_keys.rs
@@ -20,7 +20,8 @@
//! and then insert an `IsNotNull` filter on the nullable side since null values
//! can never match.
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::{
and, logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan,
@@ -42,20 +43,11 @@ impl OptimizerRule for FilterNullJoinKeys {
fn try_optimize(
&self,
plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
- // recurse down first and optimize inputs
let mut join = join.clone();
- join.left = Arc::new(
- self.try_optimize(&join.left, config)?
- .unwrap_or_else(|| join.left.as_ref().clone()),
- );
- join.right = Arc::new(
- self.try_optimize(&join.right, config)?
- .unwrap_or_else(|| join.right.as_ref().clone()),
- );
let left_schema = join.left.schema();
let right_schema = join.right.schema();
@@ -89,16 +81,17 @@ impl OptimizerRule for FilterNullJoinKeys {
}
Ok(Some(LogicalPlan::Join(join)))
}
- _ => {
- // Apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
- }
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
Self::NAME
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::BottomUp)
+ }
}
fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
@@ -115,27 +108,15 @@ fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
#[cfg(test)]
mod tests {
+ use super::*;
+ use crate::test::assert_optimized_plan_eq;
use arrow::datatypes::{DataType, Field, Schema};
-
use datafusion_common::{Column, Result};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder};
- use crate::optimizer::OptimizerContext;
-
- use super::*;
-
- fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan {
- let rule = FilterNullJoinKeys::default();
- rule.try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan")
- }
-
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
- let optimized_plan = optimize_plan(plan);
- let formatted_plan = format!("{:?}", optimized_plan);
- assert_eq!(formatted_plan, expected);
+ fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected)
}
#[test]
@@ -146,8 +127,7 @@ mod tests {
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_eq(&plan, expected)
}
#[test]
@@ -158,8 +138,7 @@ mod tests {
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_eq(&plan, expected)
}
#[test]
@@ -196,8 +175,7 @@ mod tests {
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_eq(&plan, expected)
}
#[test]
@@ -218,8 +196,7 @@ mod tests {
\n Filter: t1.optional_id + UInt32(1) IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_eq(&plan, expected)
}
#[test]
@@ -240,8 +217,7 @@ mod tests {
\n TableScan: t1\
\n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_eq(&plan, expected)
}
#[test]
@@ -264,8 +240,7 @@ mod tests {
\n TableScan: t1\
\n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
\n TableScan: t2";
- assert_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_eq(&plan, expected)
}
fn build_plan(
diff --git a/datafusion/optimizer/src/inline_table_scan.rs b/datafusion/optimizer/src/inline_table_scan.rs
index fe24e675d..1783cf0a2 100644
--- a/datafusion/optimizer/src/inline_table_scan.rs
+++ b/datafusion/optimizer/src/inline_table_scan.rs
@@ -18,7 +18,8 @@
//! Optimizer rule to replace TableScan references
//! such as DataFrames and Views and inlines the LogicalPlan
//! to support further optimization
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::{logical_plan::LogicalPlan, Expr, LogicalPlanBuilder, TableScan};
@@ -38,7 +39,7 @@ impl OptimizerRule for InlineTableScan {
fn try_optimize(
&self,
plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
match plan {
// Match only on scans without filter / projection / fetch
@@ -51,29 +52,25 @@ impl OptimizerRule for InlineTableScan {
..
}) if filters.is_empty() => {
if let Some(sub_plan) = source.get_logical_plan() {
- // Recursively apply optimization
- let plan = utils::optimize_children(self, sub_plan, config)?;
- let plan = LogicalPlanBuilder::from(plan)
+ let plan = LogicalPlanBuilder::from(sub_plan.clone())
.project(vec![Expr::Wildcard])?
.alias(table_name)?;
Ok(Some(plan.build()?))
} else {
- // No plan available, return with table scan as is
- Ok(Some(plan.clone()))
+ Ok(None)
}
}
-
- // Rest: Recurse
- _ => {
- // apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
- }
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
"inline_table_scan"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::BottomUp)
+ }
}
#[cfg(test)]
@@ -83,8 +80,8 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource};
- use crate::optimizer::OptimizerContext;
- use crate::{inline_table_scan::InlineTableScan, OptimizerRule};
+ use crate::inline_table_scan::InlineTableScan;
+ use crate::test::assert_optimized_plan_eq;
pub struct RawTableSource {}
@@ -144,26 +141,18 @@ mod tests {
}
#[test]
- fn inline_table_scan() {
- let rule = InlineTableScan::new();
-
- let source = Arc::new(CustomSource::new());
-
- let scan = LogicalPlanBuilder::scan("x".to_string(), source, None).unwrap();
-
- let plan = scan.filter(col("x.a").eq(lit(1))).unwrap().build().unwrap();
-
- let optimized_plan = rule
- .try_optimize(&plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
- let formatted_plan = format!("{:?}", optimized_plan);
+ fn inline_table_scan() -> datafusion_common::Result<()> {
+ let scan = LogicalPlanBuilder::scan(
+ "x".to_string(),
+ Arc::new(CustomSource::new()),
+ None,
+ )?;
+ let plan = scan.filter(col("x.a").eq(lit(1)))?.build()?;
let expected = "Filter: x.a = Int32(1)\
\n SubqueryAlias: x\
\n Projection: y.a\
\n TableScan: y";
- assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
+ assert_optimized_plan_eq(Arc::new(InlineTableScan::new()), &plan, expected)
}
}
diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs
index 7ef769e21..e3a86381f 100644
--- a/datafusion/optimizer/src/propagate_empty_relation.rs
+++ b/datafusion/optimizer/src/propagate_empty_relation.rs
@@ -20,7 +20,8 @@ use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_expr::{EmptyRelation, JoinType, Projection, Union};
use std::sync::Arc;
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
/// Optimization rule that bottom-up to eliminate plan by propagating empty_relation.
#[derive(Default)]
@@ -37,32 +38,28 @@ impl OptimizerRule for PropagateEmptyRelation {
fn try_optimize(
&self,
plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
+ _config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
- // optimize child plans first
- let optimized_children_plan = utils::optimize_children(self, plan, config)?;
- match &optimized_children_plan {
- LogicalPlan::EmptyRelation(_) => Ok(Some(optimized_children_plan)),
+ match plan {
+ LogicalPlan::EmptyRelation(_) => {}
LogicalPlan::Projection(_)
| LogicalPlan::Filter(_)
| LogicalPlan::Window(_)
| LogicalPlan::Sort(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Repartition(_)
- | LogicalPlan::Limit(_) => match empty_child(&optimized_children_plan)? {
- Some(empty) => Ok(Some(empty)),
- None => Ok(Some(optimized_children_plan)),
- },
+ | LogicalPlan::Limit(_) => {
+ if let Some(empty) = empty_child(plan)? {
+ return Ok(Some(empty));
+ }
+ }
LogicalPlan::CrossJoin(_) => {
- let (left_empty, right_empty) =
- binary_plan_children_is_empty(&optimized_children_plan)?;
+ let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?;
if left_empty || right_empty {
- Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
+ return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: optimized_children_plan.schema().clone(),
- })))
- } else {
- Ok(Some(optimized_children_plan))
+ schema: plan.schema().clone(),
+ })));
}
}
LogicalPlan::Join(join) => {
@@ -79,18 +76,13 @@ impl OptimizerRule for PropagateEmptyRelation {
// For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side
// columns + left side columns replaced with null values.
if join.join_type == JoinType::Inner {
- let (left_empty, right_empty) =
- binary_plan_children_is_empty(&optimized_children_plan)?;
+ let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?;
if left_empty || right_empty {
- Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
+ return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: optimized_children_plan.schema().clone(),
- })))
- } else {
- Ok(Some(optimized_children_plan))
+ schema: plan.schema().clone(),
+ })));
}
- } else {
- Ok(Some(optimized_children_plan))
}
}
LogicalPlan::Union(union) => {
@@ -105,46 +97,50 @@ impl OptimizerRule for PropagateEmptyRelation {
.collect::<Vec<_>>();
if new_inputs.len() == union.inputs.len() {
- Ok(Some(optimized_children_plan))
+ return Ok(None);
} else if new_inputs.is_empty() {
- Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
+ return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: optimized_children_plan.schema().clone(),
- })))
+ schema: plan.schema().clone(),
+ })));
} else if new_inputs.len() == 1 {
let child = (**(union.inputs.get(0).unwrap())).clone();
- if child.schema().eq(optimized_children_plan.schema()) {
- Ok(Some(child))
+ if child.schema().eq(plan.schema()) {
+ return Ok(Some(child));
} else {
- Ok(Some(LogicalPlan::Projection(Projection::new_from_schema(
- Arc::new(child),
- optimized_children_plan.schema().clone(),
- ))))
+ return Ok(Some(LogicalPlan::Projection(
+ Projection::new_from_schema(
+ Arc::new(child),
+ plan.schema().clone(),
+ ),
+ )));
}
} else {
- Ok(Some(LogicalPlan::Union(Union {
+ return Ok(Some(LogicalPlan::Union(Union {
inputs: new_inputs,
schema: union.schema.clone(),
- })))
+ })));
}
}
LogicalPlan::Aggregate(agg) => {
if !agg.group_expr.is_empty() {
- match empty_child(&optimized_children_plan)? {
- Some(empty) => Ok(Some(empty)),
- None => Ok(Some(optimized_children_plan)),
+ if let Some(empty) = empty_child(plan)? {
+ return Ok(Some(empty));
}
- } else {
- Ok(Some(optimized_children_plan))
}
}
- _ => Ok(Some(optimized_children_plan)),
+ _ => {}
}
+ Ok(None)
}
fn name(&self) -> &str {
"propagate_empty_relation"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::BottomUp)
+ }
}
fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> {
@@ -202,7 +198,10 @@ fn empty_child(plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
#[cfg(test)]
mod tests {
use crate::eliminate_filter::EliminateFilter;
- use crate::test::{test_table_scan, test_table_scan_with_name};
+ use crate::optimizer::Optimizer;
+ use crate::test::{
+ assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name,
+ };
use crate::OptimizerContext;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, ScalarValue};
@@ -214,29 +213,29 @@ mod tests {
use super::*;
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
- let rule = PropagateEmptyRelation::new();
- let optimized_plan = rule
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
- let formatted_plan = format!("{:?}", optimized_plan);
- assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimized_plan.schema());
+ fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected)
}
- fn assert_together_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
- let optimize_one = EliminateFilter::new()
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
- let optimize_two = PropagateEmptyRelation::new()
- .try_optimize(&optimize_one, &OptimizerContext::new())
- .unwrap()
+ fn assert_together_optimized_plan_eq(
+ plan: &LogicalPlan,
+ expected: &str,
+ ) -> Result<()> {
+ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
+ let optimizer = Optimizer::with_rules(vec![
+ Arc::new(EliminateFilter::new()),
+ Arc::new(PropagateEmptyRelation::new()),
+ ]);
+ let config = &mut OptimizerContext::new()
+ .with_max_passes(1)
+ .with_skip_failing_rules(false);
+ let optimized_plan = optimizer
+ .optimize(plan, config, observe)
.expect("failed to optimize plan");
- let formatted_plan = format!("{:?}", optimize_two);
+ let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
- assert_eq!(plan.schema(), optimize_two.schema());
+ assert_eq!(plan.schema(), optimized_plan.schema());
+ Ok(())
}
#[test]
@@ -248,9 +247,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_eq(&plan, expected)
}
#[test]
@@ -273,9 +270,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_together_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -289,9 +284,7 @@ mod tests {
let expected = "Projection: a, b, c\
\n TableScan: test";
- assert_together_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -316,9 +309,7 @@ mod tests {
let expected = "Union\
\n TableScan: test1\
\n TableScan: test4";
- assert_together_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -343,9 +334,7 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_together_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -372,9 +361,7 @@ mod tests {
let expected = "Union\
\n TableScan: test2\
\n TableScan: test3";
- assert_together_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -388,8 +375,7 @@ mod tests {
let expected = "Projection: a, b, c\
\n TableScan: test";
- assert_together_optimized_plan_eq(&plan, expected);
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
#[test]
@@ -404,8 +390,6 @@ mod tests {
.build()?;
let expected = "EmptyRelation";
- assert_together_optimized_plan_eq(&plan, expected);
-
- Ok(())
+ assert_together_optimized_plan_eq(&plan, expected)
}
}
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs
index 0a6110541..8e7610bcc 100644
--- a/datafusion/optimizer/src/scalar_subquery_to_join.rs
+++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs
@@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.
+use crate::optimizer::ApplyOrder;
use crate::utils::{
conjunction, exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction,
verify_not_disjunction,
};
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{context, plan_err, Column, Result};
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::logical_plan::{Filter, JoinType, Limit, Subquery};
@@ -97,20 +98,12 @@ impl OptimizerRule for ScalarSubqueryToJoin {
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => {
- // Apply optimizer rule to current input
- let optimized_input = self
- .try_optimize(filter.input(), config)?
- .unwrap_or_else(|| filter.input().as_ref().clone());
-
let (subqueries, other_exprs) =
self.extract_subquery_exprs(filter.predicate(), config)?;
if subqueries.is_empty() {
// regular filter, no subquery exists clause here
- return Ok(Some(LogicalPlan::Filter(Filter::try_new(
- filter.predicate().clone(),
- Arc::new(optimized_input),
- )?)));
+ return Ok(None);
}
// iterate through all subqueries in predicate, turning each into a join
@@ -122,24 +115,22 @@ impl OptimizerRule for ScalarSubqueryToJoin {
cur_input = optimized_subquery;
} else {
// if we can't handle all of the subqueries then bail for now
- return Ok(Some(LogicalPlan::Filter(Filter::try_new(
- filter.predicate().clone(),
- Arc::new(optimized_input),
- )?)));
+ return Ok(None);
}
}
Ok(Some(cur_input))
}
- _ => {
- // Apply the optimization to all inputs of the plan
- Ok(Some(utils::optimize_children(self, plan, config)?))
- }
+ _ => Ok(None),
}
}
fn name(&self) -> &str {
"scalar_subquery_to_join"
}
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::TopDown)
+ }
}
/// Takes a query like:
@@ -408,7 +399,11 @@ mod tests {
\n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -444,20 +439,24 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_acctbal < __sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\
- \n Inner Join: customer.c_custkey = __sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\
+ \n Filter: customer.c_acctbal < __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\
+ \n Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n SubqueryAlias: __sq_2 [o_custkey:Int64, __value:Float64;N]\
+ \n SubqueryAlias: __sq_1 [o_custkey:Int64, __value:Float64;N]\
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value [o_custkey:Int64, __value:Float64;N]\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N]\
- \n Filter: orders.o_totalprice < __sq_1.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\
- \n Inner Join: orders.o_orderkey = __sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\
+ \n Filter: orders.o_totalprice < __sq_2.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\
+ \n Inner Join: orders.o_orderkey = __sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __sq_1 [l_orderkey:Int64, __value:Float64;N]\
+ \n SubqueryAlias: __sq_2 [l_orderkey:Int64, __value:Float64;N]\
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS __value [l_orderkey:Int64, __value:Float64;N]\
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] [l_orderkey:Int64, SUM(lineitem.l_extendedprice):Float64;N]\
\n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -490,7 +489,11 @@ mod tests {
\n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -520,7 +523,11 @@ mod tests {
\n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\
\n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -550,7 +557,11 @@ mod tests {
\n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -572,7 +583,7 @@ mod tests {
let expected = r#"only joins on column equality are presently supported"#;
- assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected);
Ok(())
}
@@ -593,7 +604,7 @@ mod tests {
.build()?;
let expected = r#"can't optimize < column comparison"#;
- assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected);
Ok(())
}
@@ -618,7 +629,7 @@ mod tests {
.build()?;
let expected = r#"Optimizing disjunctions not supported!"#;
- assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected);
Ok(())
}
@@ -644,7 +655,11 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -670,7 +685,11 @@ mod tests {
let expected = r#""#;
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -694,7 +713,7 @@ mod tests {
.build()?;
let expected = r#"exactly one expression should be projected"#;
- assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected);
Ok(())
}
@@ -728,7 +747,11 @@ mod tests {
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -760,7 +783,11 @@ mod tests {
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -793,7 +820,11 @@ mod tests {
Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -822,7 +853,11 @@ mod tests {
\n Aggregate: groupBy=[[sq.a]], aggr=[[MIN(sq.c)]] [a:UInt32, MIN(sq.c):UInt32;N]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -850,7 +885,11 @@ mod tests {
\n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
@@ -877,7 +916,11 @@ mod tests {
\n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected);
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ScalarSubqueryToJoin::new()),
+ &plan,
+ expected,
+ );
Ok(())
}
}
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 462b94dd0..a51c2ec29 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use crate::optimizer::Optimizer;
use crate::{OptimizerContext, OptimizerRule};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
@@ -102,24 +103,53 @@ pub fn get_tpch_table_schema(table: &str) -> Schema {
}
pub fn assert_optimized_plan_eq(
- rule: &dyn OptimizerRule,
+ rule: Arc<dyn OptimizerRule + Send + Sync>,
+ plan: &LogicalPlan,
+ expected: &str,
+) -> Result<()> {
+ let optimizer = Optimizer::with_rules(vec![rule]);
+ let optimized_plan = optimizer
+ .optimize_recursively(
+ optimizer.rules.get(0).unwrap(),
+ plan,
+ &OptimizerContext::new(),
+ )?
+ .unwrap();
+ let formatted_plan = format!("{:?}", optimized_plan);
+ assert_eq!(formatted_plan, expected);
+
+ Ok(())
+}
+
+pub fn assert_optimized_plan_eq_display_indent(
+ rule: Arc<dyn OptimizerRule + Send + Sync>,
plan: &LogicalPlan,
expected: &str,
) {
- let optimized_plan = rule
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .expect("failed to optimize plan");
+ let optimizer = Optimizer::with_rules(vec![rule]);
+ let optimized_plan = optimizer
+ .optimize_recursively(
+ optimizer.rules.get(0).unwrap(),
+ plan,
+ &OptimizerContext::new(),
+ )
+ .expect("failed to optimize plan")
+ .unwrap_or_else(|| plan.clone());
let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
assert_eq!(formatted_plan, expected);
}
pub fn assert_optimizer_err(
- rule: &dyn OptimizerRule,
+ rule: Arc<dyn OptimizerRule + Send + Sync>,
plan: &LogicalPlan,
expected: &str,
) {
- let res = rule.try_optimize(plan, &OptimizerContext::new());
+ let optimizer = Optimizer::with_rules(vec![rule]);
+ let res = optimizer.optimize_recursively(
+ optimizer.rules.get(0).unwrap(),
+ plan,
+ &OptimizerContext::new(),
+ );
match res {
Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"),
Err(ref e) => {
@@ -131,13 +161,21 @@ pub fn assert_optimizer_err(
}
}
-pub fn assert_optimization_skipped(rule: &dyn OptimizerRule, plan: &LogicalPlan) {
- let new_plan = rule
- .try_optimize(plan, &OptimizerContext::new())
- .unwrap()
- .unwrap();
+pub fn assert_optimization_skipped(
+ rule: Arc<dyn OptimizerRule + Send + Sync>,
+ plan: &LogicalPlan,
+) -> Result<()> {
+ let optimizer = Optimizer::with_rules(vec![rule]);
+ let new_plan = optimizer
+ .optimize_recursively(
+ optimizer.rules.get(0).unwrap(),
+ plan,
+ &OptimizerContext::new(),
+ )?
+ .unwrap_or_else(|| plan.clone());
assert_eq!(
format!("{}", plan.display_indent()),
format!("{}", new_plan.display_indent())
);
+ Ok(())
}