You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/09 14:53:27 UTC
[arrow-datafusion] branch master updated: Add support for non-column key for equijoin when eliminating cross join to inner join (#4443)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 097a3de03 Add support for non-column key for equijoin when eliminating cross join to inner join (#4443)
097a3de03 is described below
commit 097a3de03efb4e1fb81e191cf823f21a78937c19
Author: ygf11 <ya...@gmail.com>
AuthorDate: Fri Dec 9 22:53:20 2022 +0800
Add support for non-column key for equijoin when eliminating cross join to inner join (#4443)
* Support non-column join key in eliminating cross join to inner join
* Add comment
* Make clippy happy
* Add tests
* Add alias for cast expr join keys
* Add tests
* Add relative issue comment
* Improve test
* Improve use declarations
---
datafusion/core/tests/sql/joins.rs | 96 +++++++
datafusion/expr/src/lib.rs | 4 +-
datafusion/expr/src/logical_plan/builder.rs | 65 ++++-
datafusion/optimizer/src/eliminate_cross_join.rs | 336 ++++++++++++++++++-----
datafusion/sql/src/planner.rs | 69 ++---
5 files changed, 454 insertions(+), 116 deletions(-)
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 6fbf16c03..68eea7991 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2304,3 +2304,99 @@ async fn error_cross_join() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ // reduce to inner join
+ let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Inner Join: t1.t1_id + Int64(12) = t2.t2_id + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int, CAST(t1.t1_id AS Int64) + Int64(12) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t2.t2_id, t2.t2_name, t2.t2_int, CAST(t2.t2_id AS Int64) + Int64(1) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let expected = vec![
+ "+-------+---------+--------+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+ "+-------+---------+--------+-------+---------+--------+",
+ "| 11 | a | 1 | 22 | y | 1 |",
+ "| 33 | c | 3 | 44 | x | 3 |",
+ "| 44 | d | 4 | 55 | w | 3 |",
+ "+-------+---------+--------+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ // reduce to inner join, t2.t2_id will insert cast.
+ let sql =
+ "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx
+ .create_logical_plan(&("explain ".to_owned() + sql))
+ .expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]",
+ " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N, t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
+ " Projection: t1.t1_id, t1.t1_name, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+ " Projection: t2.t2_id, CAST(t2.t2_id AS Int64) AS CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t1_name |",
+ "+-------+-------+---------+",
+ "| 11 | 22 | a |",
+ "| 33 | 44 | c |",
+ "| 44 | 55 | d |",
+ "+-------+-------+---------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index e061b1345..3c18b0481 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -67,7 +67,9 @@ pub use function::{
};
pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
pub use logical_plan::{
- builder::{build_join_schema, union, UNNAMED_TABLE},
+ builder::{
+ build_join_schema, union, wrap_projection_for_join_if_necessary, UNNAMED_TABLE,
+ },
Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable,
CreateMemoryTable, CreateView, CrossJoin, Distinct, DropTable, DropView,
EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit,
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 28d3ccc91..d5bc64c40 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -42,8 +42,9 @@ use datafusion_common::{
ToDFSchema,
};
use std::any::Any;
+use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;
/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";
@@ -995,6 +996,68 @@ pub fn table_scan(
LogicalPlanBuilder::scan(name.unwrap_or(UNNAMED_TABLE), table_source, projection)
}
+/// Wrap a projection around a plan if any of the join keys is not a plain column expression.
+pub fn wrap_projection_for_join_if_necessary(
+ join_keys: &[Expr],
+ input: LogicalPlan,
+) -> Result<(LogicalPlan, Vec<Column>, bool)> {
+ let input_schema = input.schema();
+ let alias_join_keys: Vec<Expr> = join_keys
+ .iter()
+ .map(|key| {
+ // The display_name() of cast expression will ignore the cast info, and show the inner expression name.
+ // If we do not add an alias, a duplicate field name error is raised for the schema when adding the projection.
+ // For example:
+ // input scan : [a, b, c],
+ // join keys: [cast(a as int)]
+ //
+ // then a and cast(a as int) will use the same field name - `a` in projection schema.
+ // https://github.com/apache/arrow-datafusion/issues/4478
+ if matches!(key, Expr::Cast(_))
+ || matches!(
+ key,
+ Expr::TryCast {
+ expr: _,
+ data_type: _
+ }
+ )
+ {
+ let alias = format!("{:?}", key);
+ key.clone().alias(alias)
+ } else {
+ key.clone()
+ }
+ })
+ .collect::<Vec<_>>();
+
+ let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_)));
+ let plan = if need_project {
+ let mut projection = expand_wildcard(input_schema, &input)?;
+ let join_key_items = alias_join_keys
+ .iter()
+ .flat_map(|expr| expr.try_into_col().is_err().then_some(expr))
+ .cloned()
+ .collect::<HashSet<Expr>>();
+ projection.extend(join_key_items);
+
+ LogicalPlanBuilder::from(input)
+ .project(projection)?
+ .build()?
+ } else {
+ input
+ };
+
+ let join_on = alias_join_keys
+ .into_iter()
+ .map(|key| {
+ key.try_into_col()
+ .or_else(|_| Ok(Column::from_name(key.display_name()?)))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok((plan, join_on, need_project))
+}
+
/// Basic TableSource implementation intended for use in tests and documentation. It is expected
/// that users will provide their own TableSource implementations or use DataFusion's
/// DefaultTableSource.
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index e255fe921..8ca457771 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -17,21 +17,17 @@
//! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters.
use crate::{utils, OptimizerConfig, OptimizerRule};
-use datafusion_common::{Column, DataFusionError, Result};
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::expr::{BinaryExpr, Expr};
+use datafusion_expr::logical_plan::{
+ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
+};
+use datafusion_expr::utils::{can_hash, check_all_column_from_schema};
use datafusion_expr::{
- and, build_join_schema,
- expr::BinaryExpr,
- logical_plan::{CrossJoin, Filter, Join, JoinType, LogicalPlan},
- or,
- utils::can_hash,
- Projection,
+ and, build_join_schema, or, wrap_projection_for_join_if_necessary, ExprSchemable,
+ Operator,
};
-use datafusion_expr::{Expr, Operator};
-
use std::collections::HashSet;
-
-//use std::collections::HashMap;
-use datafusion_expr::logical_plan::JoinConstraint;
use std::sync::Arc;
#[derive(Default)]
@@ -64,7 +60,7 @@ impl OptimizerRule for EliminateCrossJoin {
LogicalPlan::Filter(filter) => {
let input = (**filter.input()).clone();
- let mut possible_join_keys: Vec<(Column, Column)> = vec![];
+ let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
let mut all_inputs: Vec<LogicalPlan> = vec![];
match &input {
LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => {
@@ -88,9 +84,9 @@ impl OptimizerRule for EliminateCrossJoin {
let predicate = filter.predicate();
// join keys are handled locally
- let mut all_join_keys: HashSet<(Column, Column)> = HashSet::new();
+ let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new();
- extract_possible_join_keys(predicate, &mut possible_join_keys);
+ extract_possible_join_keys(predicate, &mut possible_join_keys)?;
let mut left = all_inputs.remove(0);
while !all_inputs.is_empty() {
@@ -103,6 +99,7 @@ impl OptimizerRule for EliminateCrossJoin {
}
left = utils::optimize_children(self, &left, _optimizer_config)?;
+
if plan.schema() != left.schema() {
left = LogicalPlan::Projection(Projection::new_from_schema(
Arc::new(left.clone()),
@@ -139,13 +136,15 @@ impl OptimizerRule for EliminateCrossJoin {
fn flatten_join_inputs(
plan: &LogicalPlan,
- possible_join_keys: &mut Vec<(Column, Column)>,
+ possible_join_keys: &mut Vec<(Expr, Expr)>,
all_inputs: &mut Vec<LogicalPlan>,
) -> Result<()> {
let children = match plan {
LogicalPlan::Join(join) => {
for join_keys in join.on.iter() {
- possible_join_keys.push(join_keys.clone());
+ let join_keys = join_keys.clone();
+ possible_join_keys
+ .push((Expr::Column(join_keys.0), Expr::Column(join_keys.1)));
}
let left = &*(join.left);
let right = &*(join.right);
@@ -182,23 +181,49 @@ fn flatten_join_inputs(
}
fn find_inner_join(
- left: &LogicalPlan,
+ left_input: &LogicalPlan,
rights: &mut Vec<LogicalPlan>,
- possible_join_keys: &mut Vec<(Column, Column)>,
- all_join_keys: &mut HashSet<(Column, Column)>,
+ possible_join_keys: &mut Vec<(Expr, Expr)>,
+ all_join_keys: &mut HashSet<(Expr, Expr)>,
) -> Result<LogicalPlan> {
- for (i, right) in rights.iter().enumerate() {
+ for (i, right_input) in rights.iter().enumerate() {
let mut join_keys = vec![];
for (l, r) in &mut *possible_join_keys {
- if left.schema().field_from_column(l).is_ok()
- && right.schema().field_from_column(r).is_ok()
- && can_hash(left.schema().field_from_column(l).unwrap().data_type())
- {
+ let left_using_columns = l.to_columns()?;
+ let right_using_columns = r.to_columns()?;
+
+ // Conditions like a = 10, will be treated as filter.
+ if left_using_columns.is_empty() || right_using_columns.is_empty() {
+ continue;
+ }
+
+ let l_is_left = check_all_column_from_schema(
+ &left_using_columns,
+ left_input.schema().clone(),
+ )?;
+ let r_is_right = check_all_column_from_schema(
+ &right_using_columns,
+ right_input.schema().clone(),
+ )?;
+
+ let r_is_left_and_l_is_right = || {
+ let result = check_all_column_from_schema(
+ &right_using_columns,
+ left_input.schema().clone(),
+ )? && check_all_column_from_schema(
+ &left_using_columns,
+ right_input.schema().clone(),
+ )?;
+
+ Result::Ok(result)
+ };
+
+ // Save join keys
+ if l_is_left && r_is_right && can_hash(&l.get_type(left_input.schema())?) {
join_keys.push((l.clone(), r.clone()));
- } else if left.schema().field_from_column(r).is_ok()
- && right.schema().field_from_column(l).is_ok()
- && can_hash(left.schema().field_from_column(r).unwrap().data_type())
+ } else if r_is_left_and_l_is_right()?
+ && can_hash(&l.get_type(right_input.schema())?)
{
join_keys.push((r.clone(), l.clone()));
}
@@ -206,18 +231,33 @@ fn find_inner_join(
if !join_keys.is_empty() {
all_join_keys.extend(join_keys.clone());
- let right = rights.remove(i);
+ let right_input = rights.remove(i);
let join_schema = Arc::new(build_join_schema(
- left.schema(),
- right.schema(),
+ left_input.schema(),
+ right_input.schema(),
&JoinType::Inner,
)?);
+
+ // Wrap projection
+ let (left_on, right_on): (Vec<Expr>, Vec<Expr>) =
+ join_keys.into_iter().unzip();
+ let (new_left_input, new_left_on, _) =
+ wrap_projection_for_join_if_necessary(&left_on, left_input.clone())?;
+ let (new_right_input, new_right_on, _) =
+ wrap_projection_for_join_if_necessary(&right_on, right_input)?;
+
+ // Build new join on
+ let join_on = new_left_on
+ .into_iter()
+ .zip(new_right_on.into_iter())
+ .collect::<Vec<_>>();
+
return Ok(LogicalPlan::Join(Join {
- left: Arc::new(left.clone()),
- right: Arc::new(right),
+ left: Arc::new(new_left_input),
+ right: Arc::new(new_right_input),
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
- on: join_keys,
+ on: join_on,
filter: None,
schema: join_schema,
null_equals_null: false,
@@ -226,22 +266,22 @@ fn find_inner_join(
}
let right = rights.remove(0);
let join_schema = Arc::new(build_join_schema(
- left.schema(),
+ left_input.schema(),
right.schema(),
&JoinType::Inner,
)?);
Ok(LogicalPlan::CrossJoin(CrossJoin {
- left: Arc::new(left.clone()),
+ left: Arc::new(left_input.clone()),
right: Arc::new(right),
schema: join_schema,
}))
}
fn intersect(
- accum: &mut Vec<(Column, Column)>,
- vec1: &[(Column, Column)],
- vec2: &[(Column, Column)],
+ accum: &mut Vec<(Expr, Expr)>,
+ vec1: &[(Expr, Expr)],
+ vec2: &[(Expr, Expr)],
) {
for x1 in vec1.iter() {
for x2 in vec2.iter() {
@@ -253,38 +293,35 @@ fn intersect(
}
/// Extract join keys from a WHERE clause
-fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) {
+fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> {
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
match op {
Operator::Eq => {
- if let (Expr::Column(l), Expr::Column(r)) =
- (left.as_ref(), right.as_ref())
+ // Ensure that we don't add the same Join keys multiple times
+ if !(accum.contains(&(*left.clone(), *right.clone()))
+ || accum.contains(&(*right.clone(), *left.clone())))
{
- // Ensure that we don't add the same Join keys multiple times
- if !(accum.contains(&(l.clone(), r.clone()))
- || accum.contains(&(r.clone(), l.clone())))
- {
- accum.push((l.clone(), r.clone()));
- }
+ accum.push((*left.clone(), *right.clone()));
}
}
Operator::And => {
- extract_possible_join_keys(left, accum);
- extract_possible_join_keys(right, accum)
+ extract_possible_join_keys(left, accum)?;
+ extract_possible_join_keys(right, accum)?
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::Or => {
let mut left_join_keys = vec![];
let mut right_join_keys = vec![];
- extract_possible_join_keys(left, &mut left_join_keys);
- extract_possible_join_keys(right, &mut right_join_keys);
+ extract_possible_join_keys(left, &mut left_join_keys)?;
+ extract_possible_join_keys(right, &mut right_join_keys)?;
intersect(accum, &left_join_keys, &right_join_keys)
}
_ => (),
- }
+ };
}
+ Ok(())
}
/// Remove join expressions from a filter expression
@@ -292,25 +329,22 @@ fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) {
/// Returns None otherwise
fn remove_join_expressions(
expr: &Expr,
- join_columns: &HashSet<(Column, Column)>,
+ join_keys: &HashSet<(Expr, Expr)>,
) -> Result<Option<Expr>> {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
- Operator::Eq => match (left.as_ref(), right.as_ref()) {
- (Expr::Column(l), Expr::Column(r)) => {
- if join_columns.contains(&(l.clone(), r.clone()))
- || join_columns.contains(&(r.clone(), l.clone()))
- {
- Ok(None)
- } else {
- Ok(Some(expr.clone()))
- }
+ Operator::Eq => {
+ if join_keys.contains(&(*left.clone(), *right.clone()))
+ || join_keys.contains(&(*right.clone(), *left.clone()))
+ {
+ Ok(None)
+ } else {
+ Ok(Some(expr.clone()))
}
- _ => Ok(Some(expr.clone())),
- },
+ }
Operator::And => {
- let l = remove_join_expressions(left, join_columns)?;
- let r = remove_join_expressions(right, join_columns)?;
+ let l = remove_join_expressions(left, join_keys)?;
+ let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(and(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
@@ -320,8 +354,8 @@ fn remove_join_expressions(
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Operator::Or => {
- let l = remove_join_expressions(left, join_columns)?;
- let r = remove_join_expressions(right, join_columns)?;
+ let l = remove_join_expressions(left, join_keys)?;
+ let r = remove_join_expressions(right, join_keys)?;
match (l, r) {
(Some(ll), Some(rr)) => Ok(Some(or(ll, rr))),
(Some(ll), _) => Ok(Some(ll)),
@@ -1052,4 +1086,168 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn eliminate_cross_join_with_expr_and() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ // could eliminate to inner join since filter has Join predicates
+ let plan = LogicalPlanBuilder::from(t1)
+ .cross_join(&t2)?
+ .filter(binary_expr(
+ (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
+ And,
+ col("t2.c").lt(lit(20u32)),
+ ))?
+ .build()?;
+
+ let expected = vec![
+ "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
+
+ assert_optimized_plan_eq(&plan, expected);
+
+ Ok(())
+ }
+
+ #[test]
+ fn eliminate_cross_with_expr_or() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ // cannot be reduced to an inner join: the filter is an OR expression and there is no
+ // join predicate common to both the left and right sides of the OR.
+ let plan = LogicalPlanBuilder::from(t1)
+ .cross_join(&t2)?
+ .filter(binary_expr(
+ (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
+ Or,
+ col("t2.b").eq(col("t1.a")),
+ ))?
+ .build()?;
+
+ let expected = vec![
+ "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
+
+ assert_optimized_plan_eq(&plan, expected);
+
+ Ok(())
+ }
+
+ #[test]
+ fn eliminate_cross_with_common_expr_and() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ // could eliminate to inner join
+ let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
+ let plan = LogicalPlanBuilder::from(t1)
+ .cross_join(&t2)?
+ .filter(binary_expr(
+ binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
+ And,
+ binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
+ ))?
+ .build()?;
+
+ let expected = vec![
+ "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
+
+ assert_optimized_plan_eq(&plan, expected);
+
+ Ok(())
+ }
+
+ #[test]
+ fn eliminate_cross_with_common_expr_or() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ // could eliminate to inner join since Or predicates have common Join predicates
+ let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
+ let plan = LogicalPlanBuilder::from(t1)
+ .cross_join(&t2)?
+ .filter(binary_expr(
+ binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
+ Or,
+ binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
+ ))?
+ .build()?;
+
+ let expected = vec![
+ "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
+
+ assert_optimized_plan_eq(&plan, expected);
+
+ Ok(())
+ }
+
+ #[test]
+ fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+ let t3 = test_table_scan_with_name("t3")?;
+
+ // could eliminate to inner join
+ let plan = LogicalPlanBuilder::from(t1)
+ .cross_join(&t2)?
+ .cross_join(&t3)?
+ .filter(binary_expr(
+ binary_expr(
+ (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
+ And,
+ col("t3.c").lt(lit(15u32)),
+ ),
+ And,
+ binary_expr(
+ (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
+ And,
+ col("t3.b").lt(lit(15u32)),
+ ),
+ ))?
+ .build()?;
+
+ let expected = vec![
+ "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
+ " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
+ " Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
+ " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
+ " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
+
+ assert_optimized_plan_eq(&plan, expected);
+
+ Ok(())
+ }
}
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 7a27f40d8..6587aa30d 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -45,7 +45,10 @@ use datafusion_common::{OwnedTableReference, TableReference};
use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like};
use datafusion_expr::expr_rewriter::normalize_col;
use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
-use datafusion_expr::logical_plan::builder::project;
+use datafusion_expr::logical_plan::builder::{
+ project, wrap_projection_for_join_if_necessary,
+};
+use datafusion_expr::logical_plan::Join as HashJoin;
use datafusion_expr::logical_plan::JoinConstraint as HashJoinConstraint;
use datafusion_expr::logical_plan::{
Analyze, CreateCatalog, CreateCatalogSchema,
@@ -53,8 +56,7 @@ use datafusion_expr::logical_plan::{
DropTable, DropView, Explain, JoinType, LogicalPlan, LogicalPlanBuilder,
Partitioning, PlanType, SetVariable, ToStringifiedPlan,
};
-use datafusion_expr::logical_plan::{Filter, Subquery};
-use datafusion_expr::logical_plan::{Join as HashJoin, Prepare};
+use datafusion_expr::logical_plan::{Filter, Prepare, Subquery};
use datafusion_expr::utils::{
can_hash, check_all_column_from_schema, expand_qualified_wildcard, expand_wildcard,
expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_column_exprs,
@@ -862,26 +864,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.build()
} else {
// Wrap projection for left input if left join keys contain normal expression.
- let (left_child, left_projected) =
+ let (left_child, left_join_keys, left_projected) =
wrap_projection_for_join_if_necessary(&left_keys, left)?;
- let left_join_keys = left_keys
- .iter()
- .map(|key| {
- key.try_into_col()
- .or_else(|_| Ok(Column::from_name(key.display_name()?)))
- })
- .collect::<Result<Vec<_>>>()?;
// Wrap projection for right input if right join keys contains normal expression.
- let (right_child, right_projected) =
+ let (right_child, right_join_keys, right_projected) =
wrap_projection_for_join_if_necessary(&right_keys, right)?;
- let right_join_keys = right_keys
- .iter()
- .map(|key| {
- key.try_into_col()
- .or_else(|_| Ok(Column::from_name(key.display_name()?)))
- })
- .collect::<Result<Vec<_>>>()?;
let join_plan_builder = LogicalPlanBuilder::from(left_child).join(
&right_child,
@@ -3228,32 +3216,6 @@ fn extract_join_keys(
Ok(())
}
-/// Wrap projection for a plan, if the join keys contains normal expression.
-fn wrap_projection_for_join_if_necessary(
- join_keys: &[Expr],
- input: LogicalPlan,
-) -> Result<(LogicalPlan, bool)> {
- let expr_join_keys = join_keys
- .iter()
- .flat_map(|expr| expr.try_into_col().is_err().then_some(expr))
- .cloned()
- .collect::<HashSet<Expr>>();
-
- let need_project = !expr_join_keys.is_empty();
- let plan = if need_project {
- let mut projection = vec![Expr::Wildcard];
- projection.extend(expr_join_keys.into_iter());
-
- LogicalPlanBuilder::from(input)
- .project(projection)?
- .build()?
- } else {
- input
- };
-
- Ok((plan, need_project))
-}
-
/// Ensure any column reference of the expression is unambiguous.
/// Assume we have two schema:
/// schema1: a, b ,c
@@ -6432,6 +6394,23 @@ mod tests {
quick_test(sql, expected);
}
+ #[test]
+ fn test_inner_join_with_cast_key() {
+ let sql = "SELECT person.id, person.age
+ FROM person
+ INNER JOIN orders
+ ON cast(person.id as Int) = cast(orders.customer_id as Int)";
+
+ let expected = "Projection: person.id, person.age\
+ \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\
+ \n Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS Int32)\
+ \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, CAST(person.id AS Int32) AS CAST(person.id AS Int32)\
+ \n TableScan: person\
+ \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, CAST(orders.customer_id AS Int32) AS CAST(orders.customer_id AS Int32)\
+ \n TableScan: orders";
+ quick_test(sql, expected);
+ }
+
fn assert_field_not_found(err: DataFusionError, name: &str) {
match err {
DataFusionError::SchemaError { .. } => {