You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/09 14:53:27 UTC

[arrow-datafusion] branch master updated: Add support for non-column key for equijoin when eliminating cross join to inner join (#4443)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 097a3de03 Add support for non-column key for equijoin when eliminating cross join to inner join (#4443)
097a3de03 is described below

commit 097a3de03efb4e1fb81e191cf823f21a78937c19
Author: ygf11 <ya...@gmail.com>
AuthorDate: Fri Dec 9 22:53:20 2022 +0800

    Add support for non-column key for equijoin when eliminating cross join to inner join (#4443)
    
    * Support non-column join key in eliminating cross join to inner join
    
    * Add comment
    
    * Make clippy happy
    
    * Add tests
    
    * Add alias for cast expr join keys
    
    * Add tests
    
    * Add related issue comment
    
    * Improve test
    
    * Improve use declarations
---
 datafusion/core/tests/sql/joins.rs               |  96 +++++++
 datafusion/expr/src/lib.rs                       |   4 +-
 datafusion/expr/src/logical_plan/builder.rs      |  65 ++++-
 datafusion/optimizer/src/eliminate_cross_join.rs | 336 ++++++++++++++++++-----
 datafusion/sql/src/planner.rs                    |  69 ++---
 5 files changed, 454 insertions(+), 116 deletions(-)

diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 6fbf16c03..68eea7991 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2304,3 +2304,99 @@ async fn error_cross_join() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        // reduce to inner join
+        let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1";
+        let msg = format!("Creating logical plan for '{}'", sql);
+        let plan = ctx
+            .create_logical_plan(&("explain ".to_owned() + sql))
+            .expect(&msg);
+        let state = ctx.state();
+        let plan = state.optimize(&plan)?;
+        let expected = vec![
+            "Explain [plan_type:Utf8, plan:Utf8]",
+            "  Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "    Inner Join: t1.t1_id + Int64(12) = t2.t2_id + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
+            "      Projection: t1.t1_id, t1.t1_name, t1.t1_int, CAST(t1.t1_id AS Int64) + Int64(12) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N]",
+            "        TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      Projection: t2.t2_id, t2.t2_name, t2.t2_int, CAST(t2.t2_id AS Int64) + Int64(1) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
+            "        TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        ];
+        let formatted = plan.display_indent_schema().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+        let expected = vec![
+            "+-------+---------+--------+-------+---------+--------+",
+            "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+            "+-------+---------+--------+-------+---------+--------+",
+            "| 11    | a       | 1      | 22    | y       | 1      |",
+            "| 33    | c       | 3      | 44    | x       | 3      |",
+            "| 44    | d       | 4      | 55    | w       | 3      |",
+            "+-------+---------+--------+-------+---------+--------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        // reduce to inner join, t2.t2_id will insert cast.
+        let sql =
+            "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)";
+        let msg = format!("Creating logical plan for '{}'", sql);
+        let plan = ctx
+            .create_logical_plan(&("explain ".to_owned() + sql))
+            .expect(&msg);
+        let state = ctx.state();
+        let plan = state.optimize(&plan)?;
+        let expected = vec![
+            "Explain [plan_type:Utf8, plan:Utf8]",
+            "  Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]",
+            "    Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+            "      Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N, t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
+            "        Projection: t1.t1_id, t1.t1_name, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N]",
+            "          TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+            "        Projection: t2.t2_id, CAST(t2.t2_id AS Int64) AS CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
+            "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+        ];
+        let formatted = plan.display_indent_schema().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+        let expected = vec![
+            "+-------+-------+---------+",
+            "| t1_id | t2_id | t1_name |",
+            "+-------+-------+---------+",
+            "| 11    | 22    | a       |",
+            "| 33    | 44    | c       |",
+            "| 44    | 55    | d       |",
+            "+-------+-------+---------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index e061b1345..3c18b0481 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -67,7 +67,9 @@ pub use function::{
 };
 pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
 pub use logical_plan::{
-    builder::{build_join_schema, union, UNNAMED_TABLE},
+    builder::{
+        build_join_schema, union, wrap_projection_for_join_if_necessary, UNNAMED_TABLE,
+    },
     Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable,
     CreateMemoryTable, CreateView, CrossJoin, Distinct, DropTable, DropView,
     EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit,
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 28d3ccc91..d5bc64c40 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -42,8 +42,9 @@ use datafusion_common::{
     ToDFSchema,
 };
 use std::any::Any;
+use std::collections::{HashMap, HashSet};
 use std::convert::TryFrom;
-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;
 
 /// Default table name for unnamed table
 pub const UNNAMED_TABLE: &str = "?table?";
@@ -995,6 +996,68 @@ pub fn table_scan(
     LogicalPlanBuilder::scan(name.unwrap_or(UNNAMED_TABLE), table_source, projection)
 }
 
+/// Wrap `input` in a projection if any join key is a non-column expression; returns the (possibly wrapped) plan, the join keys as columns, and whether a projection was added.
+pub fn wrap_projection_for_join_if_necessary(
+    join_keys: &[Expr],
+    input: LogicalPlan,
+) -> Result<(LogicalPlan, Vec<Column>, bool)> {
+    let input_schema = input.schema();
+    let alias_join_keys: Vec<Expr> = join_keys
+        .iter()
+        .map(|key| {
+            // The display_name() of cast expression will ignore the cast info, and show the inner expression name.
+            // If we do not add an alias, adding the projection fails with a duplicate field name error in the schema.
+            // For example:
+            //    input scan : [a, b, c],
+            //    join keys: [cast(a as int)]
+            //
+            //  then a and cast(a as int) will use the same field name - `a` in projection schema.
+            //  https://github.com/apache/arrow-datafusion/issues/4478
+            if matches!(key, Expr::Cast(_))
+                || matches!(
+                    key,
+                    Expr::TryCast {
+                        expr: _,
+                        data_type: _
+                    }
+                )
+            {
+                let alias = format!("{:?}", key);
+                key.clone().alias(alias)
+            } else {
+                key.clone()
+            }
+        })
+        .collect::<Vec<_>>();
+
+    let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_)));
+    let plan = if need_project {
+        let mut projection = expand_wildcard(input_schema, &input)?;
+        let join_key_items = alias_join_keys
+            .iter()
+            .flat_map(|expr| expr.try_into_col().is_err().then_some(expr))
+            .cloned()
+            .collect::<HashSet<Expr>>();
+        projection.extend(join_key_items);
+
+        LogicalPlanBuilder::from(input)
+            .project(projection)?
+            .build()?
+    } else {
+        input
+    };
+
+    let join_on = alias_join_keys
+        .into_iter()
+        .map(|key| {
+            key.try_into_col()
+                .or_else(|_| Ok(Column::from_name(key.display_name()?)))
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok((plan, join_on, need_project))
+}
+
 /// Basic TableSource implementation intended for use in tests and documentation. It is expected
 /// that users will provide their own TableSource implementations or use DataFusion's
 /// DefaultTableSource.
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index e255fe921..8ca457771 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -17,21 +17,17 @@
 
 //! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters.
 use crate::{utils, OptimizerConfig, OptimizerRule};
-use datafusion_common::{Column, DataFusionError, Result};
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::expr::{BinaryExpr, Expr};
+use datafusion_expr::logical_plan::{
+    CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
+};
+use datafusion_expr::utils::{can_hash, check_all_column_from_schema};
 use datafusion_expr::{
-    and, build_join_schema,
-    expr::BinaryExpr,
-    logical_plan::{CrossJoin, Filter, Join, JoinType, LogicalPlan},
-    or,
-    utils::can_hash,
-    Projection,
+    and, build_join_schema, or, wrap_projection_for_join_if_necessary, ExprSchemable,
+    Operator,
 };
-use datafusion_expr::{Expr, Operator};
-
 use std::collections::HashSet;
-
-//use std::collections::HashMap;
-use datafusion_expr::logical_plan::JoinConstraint;
 use std::sync::Arc;
 
 #[derive(Default)]
@@ -64,7 +60,7 @@ impl OptimizerRule for EliminateCrossJoin {
             LogicalPlan::Filter(filter) => {
                 let input = (**filter.input()).clone();
 
-                let mut possible_join_keys: Vec<(Column, Column)> = vec![];
+                let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
                 let mut all_inputs: Vec<LogicalPlan> = vec![];
                 match &input {
                     LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => {
@@ -88,9 +84,9 @@ impl OptimizerRule for EliminateCrossJoin {
 
                 let predicate = filter.predicate();
                 // join keys are handled locally
-                let mut all_join_keys: HashSet<(Column, Column)> = HashSet::new();
+                let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new();
 
-                extract_possible_join_keys(predicate, &mut possible_join_keys);
+                extract_possible_join_keys(predicate, &mut possible_join_keys)?;
 
                 let mut left = all_inputs.remove(0);
                 while !all_inputs.is_empty() {
@@ -103,6 +99,7 @@ impl OptimizerRule for EliminateCrossJoin {
                 }
 
                 left = utils::optimize_children(self, &left, _optimizer_config)?;
+
                 if plan.schema() != left.schema() {
                     left = LogicalPlan::Projection(Projection::new_from_schema(
                         Arc::new(left.clone()),
@@ -139,13 +136,15 @@ impl OptimizerRule for EliminateCrossJoin {
 
 fn flatten_join_inputs(
     plan: &LogicalPlan,
-    possible_join_keys: &mut Vec<(Column, Column)>,
+    possible_join_keys: &mut Vec<(Expr, Expr)>,
     all_inputs: &mut Vec<LogicalPlan>,
 ) -> Result<()> {
     let children = match plan {
         LogicalPlan::Join(join) => {
             for join_keys in join.on.iter() {
-                possible_join_keys.push(join_keys.clone());
+                let join_keys = join_keys.clone();
+                possible_join_keys
+                    .push((Expr::Column(join_keys.0), Expr::Column(join_keys.1)));
             }
             let left = &*(join.left);
             let right = &*(join.right);
@@ -182,23 +181,49 @@ fn flatten_join_inputs(
 }
 
 fn find_inner_join(
-    left: &LogicalPlan,
+    left_input: &LogicalPlan,
     rights: &mut Vec<LogicalPlan>,
-    possible_join_keys: &mut Vec<(Column, Column)>,
-    all_join_keys: &mut HashSet<(Column, Column)>,
+    possible_join_keys: &mut Vec<(Expr, Expr)>,
+    all_join_keys: &mut HashSet<(Expr, Expr)>,
 ) -> Result<LogicalPlan> {
-    for (i, right) in rights.iter().enumerate() {
+    for (i, right_input) in rights.iter().enumerate() {
         let mut join_keys = vec![];
 
         for (l, r) in &mut *possible_join_keys {
-            if left.schema().field_from_column(l).is_ok()
-                && right.schema().field_from_column(r).is_ok()
-                && can_hash(left.schema().field_from_column(l).unwrap().data_type())
-            {
+            let left_using_columns = l.to_columns()?;
+            let right_using_columns = r.to_columns()?;
+
+            // Conditions like `a = 10` (no columns on one side) are not join keys; they stay in the filter.
+            if left_using_columns.is_empty() || right_using_columns.is_empty() {
+                continue;
+            }
+
+            let l_is_left = check_all_column_from_schema(
+                &left_using_columns,
+                left_input.schema().clone(),
+            )?;
+            let r_is_right = check_all_column_from_schema(
+                &right_using_columns,
+                right_input.schema().clone(),
+            )?;
+
+            let r_is_left_and_l_is_right = || {
+                let result = check_all_column_from_schema(
+                    &right_using_columns,
+                    left_input.schema().clone(),
+                )? && check_all_column_from_schema(
+                    &left_using_columns,
+                    right_input.schema().clone(),
+                )?;
+
+                Result::Ok(result)
+            };
+
+            // Save join keys
+            if l_is_left && r_is_right && can_hash(&l.get_type(left_input.schema())?) {
                 join_keys.push((l.clone(), r.clone()));
-            } else if left.schema().field_from_column(r).is_ok()
-                && right.schema().field_from_column(l).is_ok()
-                && can_hash(left.schema().field_from_column(r).unwrap().data_type())
+            } else if r_is_left_and_l_is_right()?
+                && can_hash(&l.get_type(right_input.schema())?)
             {
                 join_keys.push((r.clone(), l.clone()));
             }
@@ -206,18 +231,33 @@ fn find_inner_join(
 
         if !join_keys.is_empty() {
             all_join_keys.extend(join_keys.clone());
-            let right = rights.remove(i);
+            let right_input = rights.remove(i);
             let join_schema = Arc::new(build_join_schema(
-                left.schema(),
-                right.schema(),
+                left_input.schema(),
+                right_input.schema(),
                 &JoinType::Inner,
             )?);
+
+            // Wrap projection
+            let (left_on, right_on): (Vec<Expr>, Vec<Expr>) =
+                join_keys.into_iter().unzip();
+            let (new_left_input, new_left_on, _) =
+                wrap_projection_for_join_if_necessary(&left_on, left_input.clone())?;
+            let (new_right_input, new_right_on, _) =
+                wrap_projection_for_join_if_necessary(&right_on, right_input)?;
+
+            // Build new join on
+            let join_on = new_left_on
+                .into_iter()
+                .zip(new_right_on.into_iter())
+                .collect::<Vec<_>>();
+
             return Ok(LogicalPlan::Join(Join {
-                left: Arc::new(left.clone()),
-                right: Arc::new(right),
+                left: Arc::new(new_left_input),
+                right: Arc::new(new_right_input),
                 join_type: JoinType::Inner,
                 join_constraint: JoinConstraint::On,
-                on: join_keys,
+                on: join_on,
                 filter: None,
                 schema: join_schema,
                 null_equals_null: false,
@@ -226,22 +266,22 @@ fn find_inner_join(
     }
     let right = rights.remove(0);
     let join_schema = Arc::new(build_join_schema(
-        left.schema(),
+        left_input.schema(),
         right.schema(),
         &JoinType::Inner,
     )?);
 
     Ok(LogicalPlan::CrossJoin(CrossJoin {
-        left: Arc::new(left.clone()),
+        left: Arc::new(left_input.clone()),
         right: Arc::new(right),
         schema: join_schema,
     }))
 }
 
 fn intersect(
-    accum: &mut Vec<(Column, Column)>,
-    vec1: &[(Column, Column)],
-    vec2: &[(Column, Column)],
+    accum: &mut Vec<(Expr, Expr)>,
+    vec1: &[(Expr, Expr)],
+    vec2: &[(Expr, Expr)],
 ) {
     for x1 in vec1.iter() {
         for x2 in vec2.iter() {
@@ -253,38 +293,35 @@ fn intersect(
 }
 
 /// Extract join keys from a WHERE clause
-fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) {
+fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> {
     if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
         match op {
             Operator::Eq => {
-                if let (Expr::Column(l), Expr::Column(r)) =
-                    (left.as_ref(), right.as_ref())
+                // Ensure that we don't add the same Join keys multiple times
+                if !(accum.contains(&(*left.clone(), *right.clone()))
+                    || accum.contains(&(*right.clone(), *left.clone())))
                 {
-                    // Ensure that we don't add the same Join keys multiple times
-                    if !(accum.contains(&(l.clone(), r.clone()))
-                        || accum.contains(&(r.clone(), l.clone())))
-                    {
-                        accum.push((l.clone(), r.clone()));
-                    }
+                    accum.push((*left.clone(), *right.clone()));
                 }
             }
             Operator::And => {
-                extract_possible_join_keys(left, accum);
-                extract_possible_join_keys(right, accum)
+                extract_possible_join_keys(left, accum)?;
+                extract_possible_join_keys(right, accum)?
             }
             // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
             Operator::Or => {
                 let mut left_join_keys = vec![];
                 let mut right_join_keys = vec![];
 
-                extract_possible_join_keys(left, &mut left_join_keys);
-                extract_possible_join_keys(right, &mut right_join_keys);
+                extract_possible_join_keys(left, &mut left_join_keys)?;
+                extract_possible_join_keys(right, &mut right_join_keys)?;
 
                 intersect(accum, &left_join_keys, &right_join_keys)
             }
             _ => (),
-        }
+        };
     }
+    Ok(())
 }
 
 /// Remove join expressions from a filter expression
@@ -292,25 +329,22 @@ fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) {
 /// Returns None otherwise
 fn remove_join_expressions(
     expr: &Expr,
-    join_columns: &HashSet<(Column, Column)>,
+    join_keys: &HashSet<(Expr, Expr)>,
 ) -> Result<Option<Expr>> {
     match expr {
         Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
-            Operator::Eq => match (left.as_ref(), right.as_ref()) {
-                (Expr::Column(l), Expr::Column(r)) => {
-                    if join_columns.contains(&(l.clone(), r.clone()))
-                        || join_columns.contains(&(r.clone(), l.clone()))
-                    {
-                        Ok(None)
-                    } else {
-                        Ok(Some(expr.clone()))
-                    }
+            Operator::Eq => {
+                if join_keys.contains(&(*left.clone(), *right.clone()))
+                    || join_keys.contains(&(*right.clone(), *left.clone()))
+                {
+                    Ok(None)
+                } else {
+                    Ok(Some(expr.clone()))
                 }
-                _ => Ok(Some(expr.clone())),
-            },
+            }
             Operator::And => {
-                let l = remove_join_expressions(left, join_columns)?;
-                let r = remove_join_expressions(right, join_columns)?;
+                let l = remove_join_expressions(left, join_keys)?;
+                let r = remove_join_expressions(right, join_keys)?;
                 match (l, r) {
                     (Some(ll), Some(rr)) => Ok(Some(and(ll, rr))),
                     (Some(ll), _) => Ok(Some(ll)),
@@ -320,8 +354,8 @@ fn remove_join_expressions(
             }
             // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
             Operator::Or => {
-                let l = remove_join_expressions(left, join_columns)?;
-                let r = remove_join_expressions(right, join_columns)?;
+                let l = remove_join_expressions(left, join_keys)?;
+                let r = remove_join_expressions(right, join_keys)?;
                 match (l, r) {
                     (Some(ll), Some(rr)) => Ok(Some(or(ll, rr))),
                     (Some(ll), _) => Ok(Some(ll)),
@@ -1052,4 +1086,168 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn eliminate_cross_join_with_expr_and() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // could eliminate to inner join since filter has Join predicates
+        let plan = LogicalPlanBuilder::from(t1)
+            .cross_join(&t2)?
+            .filter(binary_expr(
+                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
+                And,
+                col("t2.c").lt(lit(20u32)),
+            ))?
+            .build()?;
+
+        let expected = vec![
+              "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+              "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+              "    Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+              "      Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
+              "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+              "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+              "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+        ];
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn eliminate_cross_with_expr_or() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // could not eliminate to inner join since the filter is an OR expression and there are no common
+        // join predicates on the left and right sides of the OR expr.
+        let plan = LogicalPlanBuilder::from(t1)
+            .cross_join(&t2)?
+            .filter(binary_expr(
+                (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
+                Or,
+                col("t2.b").eq(col("t1.a")),
+            ))?
+            .build()?;
+
+        let expected = vec![
+              "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+              "  CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+              "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+              "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+        ];
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn eliminate_cross_with_common_expr_and() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // could eliminate to inner join
+        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
+        let plan = LogicalPlanBuilder::from(t1)
+            .cross_join(&t2)?
+            .filter(binary_expr(
+                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))),
+                And,
+                binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))),
+            ))?
+            .build()?;
+
+        let expected = vec![
+               "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+               "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+               "    Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+               "      Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
+               "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+               "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+               "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+        ];
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn eliminate_cross_with_common_expr_or() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // could eliminate to inner join since Or predicates have common Join predicates
+        let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32));
+        let plan = LogicalPlanBuilder::from(t1)
+            .cross_join(&t2)?
+            .filter(binary_expr(
+                binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))),
+                Or,
+                binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))),
+            ))?
+            .build()?;
+
+        let expected = vec![
+               "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+               "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+               "    Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+               "      Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
+               "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+               "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+               "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+       ];
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn reorder_join_with_expr_key_multi_tables() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+        let t3 = test_table_scan_with_name("t3")?;
+
+        // could eliminate to inner join
+        let plan = LogicalPlanBuilder::from(t1)
+            .cross_join(&t2)?
+            .cross_join(&t3)?
+            .filter(binary_expr(
+                binary_expr(
+                    (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)),
+                    And,
+                    col("t3.c").lt(lit(15u32)),
+                ),
+                And,
+                binary_expr(
+                    (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)),
+                    And,
+                    col("t3.b").lt(lit(15u32)),
+                ),
+            ))?
+            .build()?;
+
+        let expected = vec![
+            "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "    Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+            "      Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
+            "        Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
+            "          Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]",
+            "            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+            "          Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
+            "            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
+            "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
+            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+        ];
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 7a27f40d8..6587aa30d 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -45,7 +45,10 @@ use datafusion_common::{OwnedTableReference, TableReference};
 use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like};
 use datafusion_expr::expr_rewriter::normalize_col;
 use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
-use datafusion_expr::logical_plan::builder::project;
+use datafusion_expr::logical_plan::builder::{
+    project, wrap_projection_for_join_if_necessary,
+};
+use datafusion_expr::logical_plan::Join as HashJoin;
 use datafusion_expr::logical_plan::JoinConstraint as HashJoinConstraint;
 use datafusion_expr::logical_plan::{
     Analyze, CreateCatalog, CreateCatalogSchema,
@@ -53,8 +56,7 @@ use datafusion_expr::logical_plan::{
     DropTable, DropView, Explain, JoinType, LogicalPlan, LogicalPlanBuilder,
     Partitioning, PlanType, SetVariable, ToStringifiedPlan,
 };
-use datafusion_expr::logical_plan::{Filter, Subquery};
-use datafusion_expr::logical_plan::{Join as HashJoin, Prepare};
+use datafusion_expr::logical_plan::{Filter, Prepare, Subquery};
 use datafusion_expr::utils::{
     can_hash, check_all_column_from_schema, expand_qualified_wildcard, expand_wildcard,
     expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_column_exprs,
@@ -862,26 +864,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         .build()
                 } else {
                     // Wrap projection for left input if left join keys contain normal expression.
-                    let (left_child, left_projected) =
+                    let (left_child, left_join_keys, left_projected) =
                         wrap_projection_for_join_if_necessary(&left_keys, left)?;
-                    let left_join_keys = left_keys
-                        .iter()
-                        .map(|key| {
-                            key.try_into_col()
-                                .or_else(|_| Ok(Column::from_name(key.display_name()?)))
-                        })
-                        .collect::<Result<Vec<_>>>()?;
 
                     // Wrap projection for right input if right join keys contains normal expression.
-                    let (right_child, right_projected) =
+                    let (right_child, right_join_keys, right_projected) =
                         wrap_projection_for_join_if_necessary(&right_keys, right)?;
-                    let right_join_keys = right_keys
-                        .iter()
-                        .map(|key| {
-                            key.try_into_col()
-                                .or_else(|_| Ok(Column::from_name(key.display_name()?)))
-                        })
-                        .collect::<Result<Vec<_>>>()?;
 
                     let join_plan_builder = LogicalPlanBuilder::from(left_child).join(
                         &right_child,
@@ -3228,32 +3216,6 @@ fn extract_join_keys(
     Ok(())
 }
 
-/// Wrap projection for a plan, if the join keys contains normal expression.
-fn wrap_projection_for_join_if_necessary(
-    join_keys: &[Expr],
-    input: LogicalPlan,
-) -> Result<(LogicalPlan, bool)> {
-    let expr_join_keys = join_keys
-        .iter()
-        .flat_map(|expr| expr.try_into_col().is_err().then_some(expr))
-        .cloned()
-        .collect::<HashSet<Expr>>();
-
-    let need_project = !expr_join_keys.is_empty();
-    let plan = if need_project {
-        let mut projection = vec![Expr::Wildcard];
-        projection.extend(expr_join_keys.into_iter());
-
-        LogicalPlanBuilder::from(input)
-            .project(projection)?
-            .build()?
-    } else {
-        input
-    };
-
-    Ok((plan, need_project))
-}
-
 /// Ensure any column reference of the expression is unambiguous.
 /// Assume we have two schema:
 /// schema1: a, b ,c
@@ -6432,6 +6394,23 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn test_inner_join_with_cast_key() {
+        let sql = "SELECT person.id, person.age
+            FROM person
+            INNER JOIN orders
+            ON cast(person.id as Int) = cast(orders.customer_id as Int)";
+
+        let expected = "Projection: person.id, person.age\
+        \n  Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\
+        \n    Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS Int32)\
+        \n      Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, CAST(person.id AS Int32) AS CAST(person.id AS Int32)\
+        \n        TableScan: person\
+        \n      Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, CAST(orders.customer_id AS Int32) AS CAST(orders.customer_id AS Int32)\
+        \n        TableScan: orders";
+        quick_test(sql, expected);
+    }
+
     fn assert_field_not_found(err: DataFusionError, name: &str) {
         match err {
             DataFusionError::SchemaError { .. } => {