You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/22 22:25:39 UTC
[arrow-datafusion] branch master updated: Add rule to reimplement `Eliminate cross join` and remove it in planner (#4185)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new d355f69aa Add rule to reimplement `Eliminate cross join` and remove it in planner (#4185)
d355f69aa is described below
commit d355f69aae2cc951cfd021e5c0b690861ba0c4ac
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Nov 23 06:25:34 2022 +0800
Add rule to reimplement `Eliminate cross join` and remove it in planner (#4185)
* reimplement eliminate_cross_join
* add test for subquery alias and projection alias.
* add test
* fix fmt
* review
* fmt
* fix conflict
---
benchmarks/expected-plans/q2.txt | 43 +--
benchmarks/expected-plans/q8.txt | 41 +--
benchmarks/expected-plans/q9.txt | 25 +-
datafusion/core/tests/sql/subqueries.rs | 43 +--
...educe_cross_join.rs => eliminate_cross_join.rs} | 363 +++++++++++++--------
datafusion/optimizer/src/lib.rs | 2 +-
datafusion/optimizer/src/optimizer.rs | 2 +-
datafusion/optimizer/tests/integration-test.rs | 36 ++
datafusion/sql/src/planner.rs | 293 +++--------------
9 files changed, 404 insertions(+), 444 deletions(-)
diff --git a/benchmarks/expected-plans/q2.txt b/benchmarks/expected-plans/q2.txt
index c5f6fb0fd..e97305509 100644
--- a/benchmarks/expected-plans/q2.txt
+++ b/benchmarks/expected-plans/q2.txt
@@ -1,24 +1,25 @@
Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment
- Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
- Inner Join: part.p_partkey = partsupp.ps_partkey
- Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
- TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
- TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name]
- Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
- Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name
+ Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ Inner Join: part.p_partkey = partsupp.ps_partkey
+ Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
+ TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name]
+ Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
+ Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
diff --git a/benchmarks/expected-plans/q8.txt b/benchmarks/expected-plans/q8.txt
index 3f5a87680..1b8d08ef8 100644
--- a/benchmarks/expected-plans/q8.txt
+++ b/benchmarks/expected-plans/q8.txt
@@ -3,23 +3,24 @@ Sort: all_nations.o_year ASC NULLS LAST
Aggregate: groupBy=[[all_nations.o_year]], aggr=[[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)]]
Projection: o_year, volume, nation, alias=all_nations
Projection: datepart(Utf8("YEAR"), orders.o_orderdate) AS o_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume, n2.n_name AS nation
- Inner Join: n1.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = n2.n_nationkey
- Inner Join: customer.c_nationkey = n1.n_nationkey
- Inner Join: orders.o_custkey = customer.c_custkey
- Inner Join: lineitem.l_orderkey = orders.o_orderkey
- Inner Join: lineitem.l_suppkey = supplier.s_suppkey
- Inner Join: part.p_partkey = lineitem.l_partkey
- Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL")
- TableScan: part projection=[p_partkey, p_type]
- TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount]
- TableScan: supplier projection=[s_suppkey, s_nationkey]
- Filter: orders.o_orderdate >= Date32("9131") AND orders.o_orderdate <= Date32("9861")
- TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate]
- TableScan: customer projection=[c_custkey, c_nationkey]
- SubqueryAlias: n1
- TableScan: nation projection=[n_nationkey, n_regionkey]
- SubqueryAlias: n2
- TableScan: nation projection=[n_nationkey, n_name]
- Filter: region.r_name = Utf8("AMERICA")
- TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
+ Projection: lineitem.l_extendedprice, lineitem.l_discount, orders.o_orderdate, n2.n_name
+ Inner Join: n1.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = n2.n_nationkey
+ Inner Join: customer.c_nationkey = n1.n_nationkey
+ Inner Join: orders.o_custkey = customer.c_custkey
+ Inner Join: lineitem.l_orderkey = orders.o_orderkey
+ Inner Join: lineitem.l_suppkey = supplier.s_suppkey
+ Inner Join: part.p_partkey = lineitem.l_partkey
+ Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL")
+ TableScan: part projection=[p_partkey, p_type]
+ TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount]
+ TableScan: supplier projection=[s_suppkey, s_nationkey]
+ Filter: orders.o_orderdate >= Date32("9131") AND orders.o_orderdate <= Date32("9861")
+ TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate]
+ TableScan: customer projection=[c_custkey, c_nationkey]
+ SubqueryAlias: n1
+ TableScan: nation projection=[n_nationkey, n_regionkey]
+ SubqueryAlias: n2
+ TableScan: nation projection=[n_nationkey, n_name]
+ Filter: region.r_name = Utf8("AMERICA")
+ TableScan: region projection=[r_regionkey, r_name]
\ No newline at end of file
diff --git a/benchmarks/expected-plans/q9.txt b/benchmarks/expected-plans/q9.txt
index 339db7017..ae7d4f194 100644
--- a/benchmarks/expected-plans/q9.txt
+++ b/benchmarks/expected-plans/q9.txt
@@ -3,15 +3,16 @@ Sort: profit.nation ASC NULLS LAST, profit.o_year DESC NULLS FIRST
Aggregate: groupBy=[[profit.nation, profit.o_year]], aggr=[[SUM(profit.amount)]]
Projection: nation, o_year, amount, alias=profit
Projection: nation.n_name AS nation, datepart(Utf8("YEAR"), orders.o_orderdate) AS o_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) - CAST(partsupp.ps_supplycost * lineitem.l_quantity AS Decimal128(38, 4)) AS amount
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: lineitem.l_orderkey = orders.o_orderkey
- Inner Join: lineitem.l_suppkey = partsupp.ps_suppkey, lineitem.l_partkey = partsupp.ps_partkey
- Inner Join: lineitem.l_suppkey = supplier.s_suppkey
- Inner Join: part.p_partkey = lineitem.l_partkey
- Filter: part.p_name LIKE Utf8("%green%")
- TableScan: part projection=[p_partkey, p_name]
- TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount]
- TableScan: supplier projection=[s_suppkey, s_nationkey]
- TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
- TableScan: orders projection=[o_orderkey, o_orderdate]
- TableScan: nation projection=[n_nationkey, n_name]
\ No newline at end of file
+ Projection: lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, partsupp.ps_supplycost, orders.o_orderdate, nation.n_name
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: lineitem.l_orderkey = orders.o_orderkey
+ Inner Join: lineitem.l_suppkey = partsupp.ps_suppkey, lineitem.l_partkey = partsupp.ps_partkey
+ Inner Join: lineitem.l_suppkey = supplier.s_suppkey
+ Inner Join: part.p_partkey = lineitem.l_partkey
+ Filter: part.p_name LIKE Utf8("%green%")
+ TableScan: part projection=[p_partkey, p_name]
+ TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount]
+ TableScan: supplier projection=[s_suppkey, s_nationkey]
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
+ TableScan: orders projection=[o_orderkey, o_orderdate]
+ TableScan: nation projection=[n_nationkey, n_name]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 064ef3a35..98bf56a02 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -141,28 +141,29 @@ order by s_acctbal desc, n_name, s_name, p_partkey;"#;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment
- Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
- Inner Join: part.p_partkey = partsupp.ps_partkey
- Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
- TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")]
- TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
- Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
- Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
- Inner Join: nation.n_regionkey = region.r_regionkey
- Inner Join: supplier.s_nationkey = nation.n_nationkey
- Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name
+ Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = __sq_1.__value
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ Inner Join: part.p_partkey = partsupp.ps_partkey
+ Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS")
+ TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")]
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
- TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
- TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
- Filter: region.r_name = Utf8("EUROPE")
- TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]"#
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]
+ Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
+ Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]]
+ Inner Join: nation.n_regionkey = region.r_regionkey
+ Inner Join: supplier.s_nationkey = nation.n_nationkey
+ Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")]"#
.to_string();
assert_eq!(actual, expected);
diff --git a/datafusion/optimizer/src/reduce_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
similarity index 77%
rename from datafusion/optimizer/src/reduce_cross_join.rs
rename to datafusion/optimizer/src/eliminate_cross_join.rs
index 45230ebb2..23e80ee54 100644
--- a/datafusion/optimizer/src/reduce_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -16,19 +16,19 @@
// under the License.
//! Optimizer rule to reduce cross join to inner join if join predicates are available in filters.
-use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{Column, Result};
+use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::{
and,
expr::BinaryExpr,
logical_plan::{CrossJoin, Filter, Join, JoinType, LogicalPlan},
or,
utils::can_hash,
- utils::from_plan,
+ Projection,
};
use datafusion_expr::{Expr, Operator};
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
//use std::collections::HashMap;
use datafusion_expr::logical_plan::JoinConstraint;
@@ -44,16 +44,92 @@ impl ReduceCrossJoin {
}
}
+/// Attempt to reorder joins to reduce cross joins to inner joins.
+/// for queries:
+/// 'select ... from a, b where a.x = b.y and b.xx = 100;'
+/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
+/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
+/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
+/// For above queries, the join predicate is available in filters and they are moved to
+/// join nodes appropriately
+/// This fix helps to improve the performance of TPCH Q19. issue#78
+///
impl OptimizerRule for ReduceCrossJoin {
fn optimize(
&self,
plan: &LogicalPlan,
_optimizer_config: &mut OptimizerConfig,
) -> Result<LogicalPlan> {
- let mut possible_join_keys: Vec<(Column, Column)> = vec![];
- let mut all_join_keys = HashSet::new();
+ match plan {
+ LogicalPlan::Filter(filter) => {
+ let input = (**filter.input()).clone();
+
+ let mut possible_join_keys: Vec<(Column, Column)> = vec![];
+ let mut all_inputs: Vec<LogicalPlan> = vec![];
+ match &input {
+ LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => {
+ flatten_join_inputs(
+ &input,
+ &mut possible_join_keys,
+ &mut all_inputs,
+ )?;
+ }
+ LogicalPlan::CrossJoin(_) => {
+ flatten_join_inputs(
+ &input,
+ &mut possible_join_keys,
+ &mut all_inputs,
+ )?;
+ }
+ _ => {
+ return utils::optimize_children(self, plan, _optimizer_config);
+ }
+ }
+
+ let predicate = filter.predicate();
+ // join keys are handled locally
+ let mut all_join_keys: HashSet<(Column, Column)> = HashSet::new();
+
+ extract_possible_join_keys(predicate, &mut possible_join_keys);
+
+ let mut left = all_inputs.remove(0);
+ while !all_inputs.is_empty() {
+ left = find_inner_join(
+ &left,
+ &mut all_inputs,
+ &mut possible_join_keys,
+ &mut all_join_keys,
+ )?;
+ }
- reduce_cross_join(self, plan, &mut possible_join_keys, &mut all_join_keys)
+ left = utils::optimize_children(self, &left, _optimizer_config)?;
+ if plan.schema() != left.schema() {
+ left = LogicalPlan::Projection(Projection::new_from_schema(
+ Arc::new(left.clone()),
+ plan.schema().clone(),
+ ));
+ }
+
+ // if there are no join keys then do nothing.
+ if all_join_keys.is_empty() {
+ Ok(LogicalPlan::Filter(Filter::try_new(
+ predicate.clone(),
+ Arc::new(left),
+ )?))
+ } else {
+ // remove join expressions from filter
+ match remove_join_expressions(predicate, &all_join_keys)? {
+ Some(filter_expr) => Ok(LogicalPlan::Filter(Filter::try_new(
+ filter_expr,
+ Arc::new(left),
+ )?)),
+ _ => Ok(left),
+ }
+ }
+ }
+
+ _ => utils::optimize_children(self, plan, _optimizer_config),
+ }
}
fn name(&self) -> &str {
@@ -61,126 +137,108 @@ impl OptimizerRule for ReduceCrossJoin {
}
}
-/// Attempt to reduce cross joins to inner joins.
-/// for queries:
-/// 'select ... from a, b where a.x = b.y and b.xx = 100;'
-/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
-/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
-/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
-/// For above queries, the join predicate is available in filters and they are moved to
-/// join nodes appropriately
-/// This fix helps to improve the performance of TPCH Q19. issue#78
-///
-fn reduce_cross_join(
- _optimizer: &ReduceCrossJoin,
+fn flatten_join_inputs(
plan: &LogicalPlan,
possible_join_keys: &mut Vec<(Column, Column)>,
- all_join_keys: &mut HashSet<(Column, Column)>,
-) -> Result<LogicalPlan> {
- match plan {
- LogicalPlan::Filter(filter) => {
- let input = filter.input();
- let predicate = filter.predicate();
- // join keys are handled locally
- let mut new_possible_join_keys: Vec<(Column, Column)> = vec![];
- let mut new_all_join_keys = HashSet::new();
-
- extract_possible_join_keys(predicate, &mut new_possible_join_keys);
-
- let new_plan = reduce_cross_join(
- _optimizer,
- input,
- &mut new_possible_join_keys,
- &mut new_all_join_keys,
- )?;
-
- // if there are no join keys then do nothing.
- if new_all_join_keys.is_empty() {
- Ok(LogicalPlan::Filter(Filter::try_new(
- predicate.clone(),
- Arc::new(new_plan),
- )?))
- } else {
- // remove join expressions from filter
- match remove_join_expressions(predicate, &new_all_join_keys)? {
- Some(filter_expr) => Ok(LogicalPlan::Filter(Filter::try_new(
- filter_expr,
- Arc::new(new_plan),
- )?)),
- _ => Ok(new_plan),
- }
+ all_inputs: &mut Vec<LogicalPlan>,
+) -> Result<()> {
+ let children = match plan {
+ LogicalPlan::Join(join) => {
+ for join_keys in join.on.iter() {
+ possible_join_keys.push(join_keys.clone());
}
+ let left = &*(join.left);
+ let right = &*(join.right);
+ Ok::<Vec<&LogicalPlan>, DataFusionError>(vec![left, right])
}
- LogicalPlan::CrossJoin(cross_join) => {
- let left_plan = reduce_cross_join(
- _optimizer,
- &cross_join.left,
- possible_join_keys,
- all_join_keys,
- )?;
- let right_plan = reduce_cross_join(
- _optimizer,
- &cross_join.right,
- possible_join_keys,
- all_join_keys,
- )?;
- // can we find a match?
- let left_schema = left_plan.schema();
- let right_schema = right_plan.schema();
- let mut join_keys = vec![];
-
- for (l, r) in possible_join_keys {
- if left_schema.field_from_column(l).is_ok()
- && right_schema.field_from_column(r).is_ok()
- && can_hash(left_schema.field_from_column(l).unwrap().data_type())
- {
- join_keys.push((l.clone(), r.clone()));
- } else if left_schema.field_from_column(r).is_ok()
- && right_schema.field_from_column(l).is_ok()
- && can_hash(left_schema.field_from_column(r).unwrap().data_type())
- {
- join_keys.push((r.clone(), l.clone()));
+ LogicalPlan::CrossJoin(join) => {
+ let left = &*(join.left);
+ let right = &*(join.right);
+ Ok::<Vec<&LogicalPlan>, DataFusionError>(vec![left, right])
+ }
+ _ => {
+ return Err(DataFusionError::Plan(
+ "flatten_join_inputs just can call join/cross_join".to_string(),
+ ));
+ }
+ }?;
+
+ for child in children.iter() {
+ match *child {
+ LogicalPlan::Join(left_join) => {
+ if left_join.join_type == JoinType::Inner {
+ flatten_join_inputs(child, possible_join_keys, all_inputs)?;
+ } else {
+ all_inputs.push((*child).clone());
}
}
+ LogicalPlan::CrossJoin(_) => {
+ flatten_join_inputs(child, possible_join_keys, all_inputs)?;
+ }
+ _ => all_inputs.push((*child).clone()),
+ }
+ }
+ Ok(())
+}
- // if there are no join keys then do nothing.
- if join_keys.is_empty() {
- Ok(LogicalPlan::CrossJoin(CrossJoin {
- left: Arc::new(left_plan),
- right: Arc::new(right_plan),
- schema: cross_join.schema.clone(),
- }))
- } else {
- // Keep track of join keys being pushed to Join nodes
- all_join_keys.extend(join_keys.clone());
-
- Ok(LogicalPlan::Join(Join {
- left: Arc::new(left_plan),
- right: Arc::new(right_plan),
- join_type: JoinType::Inner,
- join_constraint: JoinConstraint::On,
- on: join_keys,
- filter: None,
- schema: cross_join.schema.clone(),
- null_equals_null: false,
- }))
+fn find_inner_join(
+ left: &LogicalPlan,
+ rights: &mut Vec<LogicalPlan>,
+ possible_join_keys: &mut Vec<(Column, Column)>,
+ all_join_keys: &mut HashSet<(Column, Column)>,
+) -> Result<LogicalPlan> {
+ for (i, right) in rights.iter().enumerate() {
+ let mut join_keys = vec![];
+
+ for (l, r) in &mut *possible_join_keys {
+ if left.schema().field_from_column(l).is_ok()
+ && right.schema().field_from_column(r).is_ok()
+ && can_hash(left.schema().field_from_column(l).unwrap().data_type())
+ {
+ join_keys.push((l.clone(), r.clone()));
+ } else if left.schema().field_from_column(r).is_ok()
+ && right.schema().field_from_column(l).is_ok()
+ && can_hash(left.schema().field_from_column(r).unwrap().data_type())
+ {
+ join_keys.push((r.clone(), l.clone()));
}
}
- _ => {
- let expr = plan.expressions();
-
- // apply the optimization to all inputs of the plan
- let inputs = plan.inputs();
- let new_inputs = inputs
- .iter()
- .map(|plan| {
- reduce_cross_join(_optimizer, plan, possible_join_keys, all_join_keys)
- })
- .collect::<Result<Vec<_>>>()?;
-
- from_plan(plan, &expr, &new_inputs)
+
+ if !join_keys.is_empty() {
+ all_join_keys.extend(join_keys.clone());
+ let right = rights.remove(i);
+ let join_schema = Arc::new(build_join_schema(left, &right)?);
+ return Ok(LogicalPlan::Join(Join {
+ left: Arc::new(left.clone()),
+ right: Arc::new(right),
+ join_type: JoinType::Inner,
+ join_constraint: JoinConstraint::On,
+ on: join_keys,
+ filter: None,
+ schema: join_schema,
+ null_equals_null: false,
+ }));
}
}
+ let right = rights.remove(0);
+ let join_schema = Arc::new(build_join_schema(left, &right)?);
+
+ Ok(LogicalPlan::CrossJoin(CrossJoin {
+ left: Arc::new(left.clone()),
+ right: Arc::new(right),
+ schema: join_schema,
+ }))
+}
+
+fn build_join_schema(left: &LogicalPlan, right: &LogicalPlan) -> Result<DFSchema> {
+ // build join schema
+ let mut fields = vec![];
+ let mut metadata = HashMap::new();
+ fields.extend(left.schema().fields().clone());
+ fields.extend(right.schema().fields().clone());
+ metadata.extend(left.schema().metadata().clone());
+ metadata.extend(right.schema().metadata().clone());
+ DFSchema::new_with_metadata(fields, metadata)
}
fn intersect(
@@ -475,6 +533,53 @@ mod tests {
Ok(())
}
+ #[test]
+ /// ```txt
+ /// filter: a.id = b.id and a.id = c.id
+ /// cross_join a (bc)
+ /// cross_join b c
+ /// ```
+ /// Without reorder, it will be
+ /// ```txt
+ /// inner_join a (bc) on a.id = b.id and a.id = c.id
+ /// cross_join b c
+ /// ```
+ /// Reorder it to be
+ /// ```txt
+ /// inner_join (ab)c on a.id = c.id
+ /// inner_join a b on a.id = b.id
+ /// ```
+ fn reorder_join_to_reduce_cross_join_multi_tables() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+ let t3 = test_table_scan_with_name("t3")?;
+
+ // could reduce to inner join
+ let plan = LogicalPlanBuilder::from(t1)
+ .cross_join(&t2)?
+ .cross_join(&t3)?
+ .filter(binary_expr(
+ binary_expr(col("t3.a").eq(col("t1.a")), And, col("t3.c").lt(lit(15u32))),
+ And,
+ binary_expr(col("t3.a").eq(col("t2.a")), And, col("t3.b").lt(lit(15u32))),
+ ))?
+ .build()?;
+
+ let expected = vec![
+ "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
+
+ assert_optimized_plan_eq(&plan, expected);
+
+ Ok(())
+ }
+
#[test]
fn reduce_cross_join_multi_tables() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
@@ -849,14 +954,14 @@ mod tests {
let expected = vec![
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
- " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
- " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
assert_optimized_plan_eq(&plan, expected);
@@ -937,13 +1042,13 @@ mod tests {
let expected = vec![
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
- " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
- " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index e62cbbd73..467ec3b24 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -18,6 +18,7 @@
pub mod common_subexpr_eliminate;
pub mod decorrelate_where_exists;
pub mod decorrelate_where_in;
+pub mod eliminate_cross_join;
pub mod eliminate_filter;
pub mod eliminate_limit;
pub mod filter_null_join_keys;
@@ -27,7 +28,6 @@ pub mod limit_push_down;
pub mod optimizer;
pub mod projection_push_down;
pub mod propagate_empty_relation;
-pub mod reduce_cross_join;
pub mod reduce_outer_join;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index f09d2ee24..3614c8a4f 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -20,6 +20,7 @@
use crate::common_subexpr_eliminate::CommonSubexprEliminate;
use crate::decorrelate_where_exists::DecorrelateWhereExists;
use crate::decorrelate_where_in::DecorrelateWhereIn;
+use crate::eliminate_cross_join::ReduceCrossJoin;
use crate::eliminate_filter::EliminateFilter;
use crate::eliminate_limit::EliminateLimit;
use crate::filter_null_join_keys::FilterNullJoinKeys;
@@ -28,7 +29,6 @@ use crate::inline_table_scan::InlineTableScan;
use crate::limit_push_down::LimitPushDown;
use crate::projection_push_down::ProjectionPushDown;
use crate::propagate_empty_relation::PropagateEmptyRelation;
-use crate::reduce_cross_join::ReduceCrossJoin;
use crate::reduce_outer_join::ReduceOuterJoin;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index b2932963e..fb27ed5ed 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -267,6 +267,42 @@ fn propagate_empty_relation() {
assert_eq!(expected, format!("{:?}", plan));
}
+#[test]
+fn join_keys_in_subquery_alias() {
+ let sql = "SELECT * FROM test AS A, ( SELECT col_int32 as key FROM test ) AS B where A.col_int32 = B.key;";
+ let plan = test_sql(sql).unwrap();
+ let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
+ \n Inner Join: a.col_int32 = b.key\
+ \n Filter: a.col_int32 IS NOT NULL\
+ \n SubqueryAlias: a\
+ \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
+ \n Projection: key, alias=b\
+ \n Projection: test.col_int32 AS key\
+ \n Filter: test.col_int32 IS NOT NULL\
+ \n TableScan: test projection=[col_int32]";
+ assert_eq!(expected, format!("{:?}", plan));
+}
+
+#[test]
+fn join_keys_in_subquery_alias_1() {
+ let sql = "SELECT * FROM test AS A, ( SELECT test.col_int32 AS key FROM test JOIN test AS C on test.col_int32 = C.col_int32 ) AS B where A.col_int32 = B.key;";
+ let plan = test_sql(sql).unwrap();
+ let expected = "Projection: a.col_int32, a.col_uint32, a.col_utf8, a.col_date32, a.col_date64, a.col_ts_nano_none, a.col_ts_nano_utc, b.key\
+ \n Inner Join: a.col_int32 = b.key\
+ \n Filter: a.col_int32 IS NOT NULL\
+ \n SubqueryAlias: a\
+ \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
+ \n Projection: key, alias=b\
+ \n Projection: test.col_int32 AS key\
+ \n Inner Join: test.col_int32 = c.col_int32\
+ \n Filter: test.col_int32 IS NOT NULL\
+ \n TableScan: test projection=[col_int32]\
+ \n Filter: c.col_int32 IS NOT NULL\
+ \n SubqueryAlias: c\
+ \n TableScan: test projection=[col_int32]";
+ assert_eq!(expected, format!("{:?}", plan));
+}
+
fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 8e3e6d911..1e09054c8 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -57,12 +57,12 @@ use datafusion_expr::logical_plan::{
use datafusion_expr::logical_plan::{Filter, Subquery};
use datafusion_expr::utils::{
can_hash, check_all_column_from_schema, expand_qualified_wildcard, expand_wildcard,
- expr_as_column_expr, find_aggregate_exprs, find_column_exprs, find_window_exprs,
- COUNT_STAR_EXPANSION,
+ expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_column_exprs,
+ find_window_exprs, COUNT_STAR_EXPANSION,
};
use datafusion_expr::Expr::Alias;
use datafusion_expr::{
- and, cast, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable,
+ cast, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable,
GetIndexedField, Operator, ScalarUDF, WindowFrame, WindowFrameUnits,
};
use datafusion_expr::{
@@ -948,166 +948,52 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
outer_query_schema: Option<&DFSchema>,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
+ let cross_join_plan = if plans.len() == 1 {
+ plans[0].clone()
+ } else {
+ let mut left = plans[0].clone();
+ for right in plans.iter().skip(1) {
+ left = LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
+ }
+ left
+ };
match selection {
Some(predicate_expr) => {
- // build join schema
let mut fields = vec![];
- let mut metadata = std::collections::HashMap::new();
+ let mut metadata = HashMap::new();
for plan in &plans {
fields.extend_from_slice(plan.schema().fields());
metadata.extend(plan.schema().metadata().clone());
}
+
let mut join_schema = DFSchema::new_with_metadata(fields, metadata)?;
+ let mut all_schemas: Vec<DFSchemaRef> = vec![];
+ for plan in plans {
+ for schema in plan.all_schemas() {
+ all_schemas.push(schema.clone());
+ }
+ }
if let Some(outer) = outer_query_schema {
+ all_schemas.push(Arc::new(outer.clone()));
join_schema.merge(outer);
}
+ let x: Vec<&DFSchemaRef> = all_schemas.iter().collect();
let filter_expr = self.sql_to_rex(predicate_expr, &join_schema, ctes)?;
+ let mut using_columns = HashSet::new();
+ expr_to_columns(&filter_expr, &mut using_columns)?;
+ let filter_expr = normalize_col_with_schemas(
+ filter_expr,
+ x.as_slice(),
+ &[using_columns],
+ )?;
- // look for expressions of the form `<column> = <column>`
- let mut possible_join_keys = vec![];
- extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?;
-
- let mut all_join_keys = HashSet::new();
-
- let orig_plans = plans.clone();
- let mut plans = plans.into_iter();
- let mut left = plans.next().unwrap(); // have at least one plan
-
- // List of the plans that have not yet been joined
- let mut remaining_plans: Vec<Option<LogicalPlan>> =
- plans.into_iter().map(Some).collect();
-
- // Take from the list of remaining plans,
- loop {
- let mut join_keys = vec![];
-
- // Search all remaining plans for the next to
- // join. Prefer the first one that has a join
- // predicate in the predicate lists
- let plan_with_idx =
- remaining_plans.iter().enumerate().find(|(_idx, plan)| {
- // skip plans that have been joined already
- let plan = if let Some(plan) = plan {
- plan
- } else {
- return false;
- };
-
- // can we find a match?
- let left_schema = left.schema();
- let right_schema = plan.schema();
- for (l, r) in &possible_join_keys {
- if left_schema.field_from_column(l).is_ok()
- && right_schema.field_from_column(r).is_ok()
- && can_hash(
- left_schema
- .field_from_column(l)
- .unwrap() // the result must be OK
- .data_type(),
- )
- {
- join_keys.push((l.clone(), r.clone()));
- } else if left_schema.field_from_column(r).is_ok()
- && right_schema.field_from_column(l).is_ok()
- && can_hash(
- left_schema
- .field_from_column(r)
- .unwrap() // the result must be OK
- .data_type(),
- )
- {
- join_keys.push((r.clone(), l.clone()));
- }
- }
- // stop if we found join keys
- !join_keys.is_empty()
- });
-
- // If we did not find join keys, either there are
- // no more plans, or we can't find any plans that
- // can be joined with predicates
- if join_keys.is_empty() {
- assert!(plan_with_idx.is_none());
-
- // pick the first non null plan to join
- let plan_with_idx = remaining_plans
- .iter()
- .enumerate()
- .find(|(_idx, plan)| plan.is_some());
- if let Some((idx, _)) = plan_with_idx {
- let plan = std::mem::take(&mut remaining_plans[idx]).unwrap();
- left = LogicalPlanBuilder::from(left)
- .cross_join(&plan)?
- .build()?;
- } else {
- // no more plans to join
- break;
- }
- } else {
- // have a plan
- let (idx, _) = plan_with_idx.expect("found plan node");
- let plan = std::mem::take(&mut remaining_plans[idx]).unwrap();
-
- let left_keys: Vec<Column> =
- join_keys.iter().map(|(l, _)| l.clone()).collect();
- let right_keys: Vec<Column> =
- join_keys.iter().map(|(_, r)| r.clone()).collect();
- let builder = LogicalPlanBuilder::from(left);
- left = builder
- .join(&plan, JoinType::Inner, (left_keys, right_keys), None)?
- .build()?;
- }
-
- all_join_keys.extend(join_keys);
- }
-
- // remove join expressions from filter
- match remove_join_expressions(&filter_expr, &all_join_keys)? {
- Some(filter_expr) => {
- // this logic is adapted from [`LogicalPlanBuilder::filter`] to take
- // the query outer schema into account so that joins in subqueries
- // can reference outer query fields.
- let mut all_schemas: Vec<DFSchemaRef> = vec![];
- for plan in orig_plans {
- for schema in plan.all_schemas() {
- all_schemas.push(schema.clone());
- }
- }
- if let Some(outer_query_schema) = outer_query_schema {
- all_schemas.push(Arc::new(outer_query_schema.clone()));
- }
- let mut join_columns = HashSet::new();
- for (l, r) in &all_join_keys {
- join_columns.insert(l.clone());
- join_columns.insert(r.clone());
- }
- let x: Vec<&DFSchemaRef> = all_schemas.iter().collect();
- let filter_expr = normalize_col_with_schemas(
- filter_expr,
- x.as_slice(),
- &[join_columns],
- )?;
- Ok(LogicalPlan::Filter(Filter::try_new(
- filter_expr,
- Arc::new(left),
- )?))
- }
- _ => Ok(left),
- }
- }
- None => {
- if plans.len() == 1 {
- Ok(plans[0].clone())
- } else {
- let mut left = plans[0].clone();
- for right in plans.iter().skip(1) {
- left =
- LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
- }
- Ok(left)
- }
+ Ok(LogicalPlan::Filter(Filter::try_new(
+ filter_expr,
+ Arc::new(cross_join_plan),
+ )?))
}
+ None => Ok(cross_join_plan),
}
}
@@ -2707,7 +2593,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
| Value::Null
| Value::Placeholder(_) => {
return Err(DataFusionError::Plan(format!(
- "Unspported Value {}",
+ "Unsupported Value {}",
value[0]
)))
}
@@ -2718,14 +2604,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
UnaryOperator::Minus => format!("-{}", expr),
_ => {
return Err(DataFusionError::Plan(format!(
- "Unspported Value {}",
+ "Unsupported Value {}",
value[0]
)))
}
},
_ => {
return Err(DataFusionError::Plan(format!(
- "Unspported Value {}",
+ "Unsupported Value {}",
value[0]
)))
}
@@ -3054,41 +2940,6 @@ pub fn object_name_to_qualifier(sql_table_name: &ObjectName) -> String {
.join(" AND ")
}
-/// Remove join expressions from a filter expression
-fn remove_join_expressions(
- expr: &Expr,
- join_columns: &HashSet<(Column, Column)>,
-) -> Result<Option<Expr>> {
- match expr {
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
- Operator::Eq => match (left.as_ref(), right.as_ref()) {
- (Expr::Column(l), Expr::Column(r)) => {
- if join_columns.contains(&(l.clone(), r.clone()))
- || join_columns.contains(&(r.clone(), l.clone()))
- {
- Ok(None)
- } else {
- Ok(Some(expr.clone()))
- }
- }
- _ => Ok(Some(expr.clone())),
- },
- Operator::And => {
- let l = remove_join_expressions(left, join_columns)?;
- let r = remove_join_expressions(right, join_columns)?;
- match (l, r) {
- (Some(ll), Some(rr)) => Ok(Some(and(ll, rr))),
- (Some(ll), _) => Ok(Some(ll)),
- (_, Some(rr)) => Ok(Some(rr)),
- _ => Ok(None),
- }
- }
- _ => Ok(Some(expr.clone())),
- },
- _ => Ok(Some(expr.clone())),
- }
-}
-
/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs
/// Filters matching this pattern are added to `accum`
/// Filters that don't match this pattern are added to `accum_filter`
@@ -3196,30 +3047,6 @@ fn extract_join_keys(
Ok(())
}
-/// Extract join keys from a WHERE clause
-fn extract_possible_join_keys(
- expr: &Expr,
- accum: &mut Vec<(Column, Column)>,
-) -> Result<()> {
- match expr {
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
- Operator::Eq => match (left.as_ref(), right.as_ref()) {
- (Expr::Column(l), Expr::Column(r)) => {
- accum.push((l.clone(), r.clone()));
- Ok(())
- }
- _ => Ok(()),
- },
- Operator::And => {
- extract_possible_join_keys(left, accum)?;
- extract_possible_join_keys(right, accum)
- }
- _ => Ok(()),
- },
- _ => Ok(()),
- }
-}
-
/// Wrap projection for a plan, if the join keys contains normal expression.
fn wrap_projection_for_join_if_necessary(
join_keys: &[Expr],
@@ -5486,18 +5313,6 @@ mod tests {
quick_test(sql, expected);
}
- #[test]
- fn cross_join_to_inner_join() {
- let sql = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;";
- let expected = "Projection: person.id\
- \n Inner Join: lineitem.l_description = orders.o_item_id\
- \n Inner Join: person.id = lineitem.l_item_id\
- \n TableScan: person\
- \n TableScan: lineitem\
- \n TableScan: orders";
- quick_test(sql, expected);
- }
-
#[test]
fn cross_join_not_to_inner_join() {
let sql = "select person.id from person, orders, lineitem where person.id = person.age;";
@@ -5581,15 +5396,15 @@ mod tests {
AND person.state = p.state)";
let expected = "Projection: person.id\
- \n Filter: EXISTS (<subquery>)\
+ \n Filter: person.id = p.id AND EXISTS (<subquery>)\
\n Subquery:\
\n Projection: person.first_name\
- \n Filter: person.last_name = p.last_name AND person.state = p.state\
- \n Inner Join: person.id = p2.id\
+ \n Filter: person.id = p2.id AND person.last_name = p.last_name AND person.state = p.state\
+ \n CrossJoin:\
\n TableScan: person\
\n SubqueryAlias: p2\
\n TableScan: person\
- \n Inner Join: person.id = p.id\
+ \n CrossJoin:\
\n TableScan: person\
\n SubqueryAlias: p\
\n TableScan: person";
@@ -5675,8 +5490,8 @@ mod tests {
\n Subquery:\
\n Projection: COUNT(UInt8(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
- \n Filter: j2.j2_id = j1.j1_id\
- \n Inner Join: j1.j1_id = j3.j3_id\
+ \n Filter: j2.j2_id = j1.j1_id AND j1.j1_id = j3.j3_id\
+ \n CrossJoin:\
\n TableScan: j1\
\n TableScan: j3\
\n CrossJoin:\
@@ -6090,8 +5905,8 @@ mod tests {
#[test]
fn test_select_join_key_inner_join() {
let sql = "SELECT orders.customer_id * 2, person.id + 10
- FROM person
- INNER JOIN orders
+ FROM person
+ INNER JOIN orders
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: orders.customer_id * Int64(2), person.id + Int64(10)\
@@ -6107,9 +5922,9 @@ mod tests {
#[test]
fn test_non_projetion_after_inner_join() {
// There's no need to add projection for left and right, so does adding projection after join.
- let sql = "SELECT person.id, person.age
- FROM person
- INNER JOIN orders
+ let sql = "SELECT person.id, person.age
+ FROM person
+ INNER JOIN orders
ON orders.customer_id = person.id";
let expected = "Projection: person.id, person.age\
@@ -6122,9 +5937,9 @@ mod tests {
#[test]
fn test_duplicated_left_join_key_inner_join() {
// person.id * 2 happen twice in left side.
- let sql = "SELECT person.id, person.age
- FROM person
- INNER JOIN orders
+ let sql = "SELECT person.id, person.age
+ FROM person
+ INNER JOIN orders
ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = orders.order_id";
let expected = "Projection: person.id, person.age\
@@ -6140,9 +5955,9 @@ mod tests {
#[test]
fn test_duplicated_right_join_key_inner_join() {
// orders.customer_id + 10 happen twice in right side.
- let sql = "SELECT person.id, person.age
- FROM person
- INNER JOIN orders
+ let sql = "SELECT person.id, person.age
+ FROM person
+ INNER JOIN orders
ON person.id * 2 = orders.customer_id + 10 and person.id = orders.customer_id + 10";
let expected = "Projection: person.id, person.age\