You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/15 12:23:50 UTC

[arrow-datafusion] branch master updated: Support non-tuple expression for in-subquery to join (#4826)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e2daee92c Support non-tuple expression for in-subquery to join (#4826)
e2daee92c is described below

commit e2daee92c5b1c24481ac5903c82aa5bbed1395ef
Author: ygf11 <ya...@gmail.com>
AuthorDate: Sun Jan 15 20:23:45 2023 +0800

    Support non-tuple expression for in-subquery to join (#4826)
    
    * Support non-tuple expression for in-subquery to join
    
    * add tests
    
    * add comment and fix cargo fmt
    
    * fix comment
    
    * clean unused comment
    
    * Update datafusion/optimizer/src/decorrelate_where_in.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * Update datafusion/optimizer/src/decorrelate_where_in.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * Update datafusion/optimizer/src/decorrelate_where_in.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * fix comment
    
    * fix cargo fmt
    
    * add tests
    
    * fix cargo fmt
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/core/tests/sql/joins.rs               | 275 ++++++++++++++
 datafusion/core/tests/sql/subqueries.rs          |  13 +-
 datafusion/expr/src/utils.rs                     |   5 +-
 datafusion/optimizer/src/decorrelate_where_in.rs | 456 ++++++++++++++++++-----
 4 files changed, 643 insertions(+), 106 deletions(-)

diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index db5c706d3..c20c66e10 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2868,3 +2868,278 @@ async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()>
 
     Ok(())
 }
+
+#[tokio::test]
+async fn subquery_to_join_with_both_side_expr() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in (select t2.t2_id + 1 from t2)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "| 44    | d       | 4      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn subquery_to_join_with_muti_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t2.t2_int > 0)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]",
+        "          Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_int:UInt32;N]",
+        "            TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn three_projection_exprs_subquery_to_join() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+        "          Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "            TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int > 0)";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    // The `t1.t1_int > UInt32(0)` predicate should be pushed down by the filter push-down rule.
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "| 44    | d       | 4      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn in_subquery_to_join_with_outer_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in 
+                         (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name) and t1.t1_id > 0";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+        "        Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+        "          TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 11    | a       | 1      |",
+        "| 33    | c       | 3      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in 
+                         (select t2.t2_id + 1 from t2) 
+                         and t1.t1_int in(select t2.t2_int + 1 from t2)
+                         and t1.t1_id > 0";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "    LeftSemi Join: CAST(t1.t1_int AS Int64) = __correlated_sq_2.CAST(t2_int AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "          TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "        SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "          Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+        "            TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+        "      SubqueryAlias: __correlated_sq_2 [CAST(t2_int AS Int64) + Int64(1):Int64;N]",
+        "        Projection: CAST(t2.t2_int AS Int64) + Int64(1) AS CAST(t2_int AS Int64) + Int64(1) [CAST(t2_int AS Int64) + Int64(1):Int64;N]",
+        "          TableScan: t2 projection=[t2_int] [t2_int:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+---------+--------+",
+        "| t1_id | t1_name | t1_int |",
+        "+-------+---------+--------+",
+        "| 44    | d       | 4      |",
+        "+-------+---------+--------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 2627a2db0..6928e98b7 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -94,12 +94,13 @@ where o_orderstatus in (
     let dataframe = ctx.sql(sql).await.unwrap();
     let plan = dataframe.into_optimized_plan().unwrap();
     let actual = format!("{}", plan.display_indent());
-    let expected = r#"Projection: orders.o_orderkey
-  LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey
-    TableScan: orders projection=[o_orderkey, o_orderstatus]
-    SubqueryAlias: __correlated_sq_1
-      Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey
-        TableScan: lineitem projection=[l_orderkey, l_linestatus]"#;
+
+    let expected = "Projection: orders.o_orderkey\
+    \n  LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\
+    \n    TableScan: orders projection=[o_orderkey, o_orderstatus]\
+    \n    SubqueryAlias: __correlated_sq_1\
+    \n      Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\
+    \n        TableScan: lineitem projection=[l_orderkey, l_linestatus]";
     assert_eq!(actual, expected);
 
     // assert data
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 682d13321..e84ba0b6f 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -965,7 +965,10 @@ pub fn can_hash(data_type: &DataType) -> bool {
 }
 
 /// Check whether all columns are from the schema.
-fn check_all_column_from_schema(columns: &HashSet<Column>, schema: DFSchemaRef) -> bool {
+pub fn check_all_column_from_schema(
+    columns: &HashSet<Column>,
+    schema: DFSchemaRef,
+) -> bool {
     columns
         .iter()
         .all(|column| schema.index_of_column(column).is_ok())
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs
index 1aa976ce8..13e3acf78 100644
--- a/datafusion/optimizer/src/decorrelate_where_in.rs
+++ b/datafusion/optimizer/src/decorrelate_where_in.rs
@@ -17,15 +17,15 @@
 
 use crate::alias::AliasGenerator;
 use crate::optimizer::ApplyOrder;
-use crate::utils::{
-    alias_cols, conjunction, exprs_to_join_cols, find_join_exprs, merge_cols,
-    only_or_err, split_conjunction, swap_table, verify_not_disjunction,
-};
+use crate::utils::{conjunction, only_or_err, split_conjunction};
 use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{context, Result};
+use datafusion_common::{context, Column, Result};
+use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col};
 use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
+use datafusion_expr::utils::check_all_column_from_schema;
 use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
 use log::debug;
+use std::collections::{BTreeSet, HashMap};
 use std::sync::Arc;
 
 #[derive(Default)]
@@ -96,6 +96,7 @@ impl OptimizerRule for DecorrelateWhereIn {
                     return Ok(None);
                 }
 
+                // note: each subquery here comes from an `IN (<subquery>)` predicate, not `EXISTS`
                 // iterate through all exists clauses in predicate, turning each into a join
                 let mut cur_input = filter.input.as_ref().clone();
                 for subquery in subqueries {
@@ -121,81 +122,98 @@ impl OptimizerRule for DecorrelateWhereIn {
     }
 }
 
+/// Optimize a `WHERE ... IN (subquery)` predicate into a left-anti/left-semi join.
+/// If the subquery is correlated, we need to extract the join predicate from the subquery.
+///
+/// For example, given a query like:
+/// `select t1.a, t1.b from t1 where t1.a in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)`
+///
+/// The optimized plan will be:
+///
+/// ```text
+/// Projection: t1.a, t1.b
+///   LeftSemi Join:  Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c
+///     TableScan: t1
+///     SubqueryAlias: __correlated_sq_1
+///       Projection: t2.a AS a, t2.b, t2.c
+///         TableScan: t2
+/// ```
 fn optimize_where_in(
     query_info: &SubqueryInfo,
-    outer_input: &LogicalPlan,
+    left: &LogicalPlan,
     outer_other_exprs: &[Expr],
     alias: &AliasGenerator,
 ) -> Result<LogicalPlan> {
-    let proj = Projection::try_from_plan(&query_info.query.subquery)
+    let projection = Projection::try_from_plan(&query_info.query.subquery)
         .map_err(|e| context!("a projection is required", e))?;
-    let mut subqry_input = proj.input.clone();
-    let proj = only_or_err(proj.expr.as_slice())
+    let subquery_input = projection.input.clone();
+    let subquery_expr = only_or_err(projection.expr.as_slice())
         .map_err(|e| context!("single expression projection required", e))?;
-    let subquery_col = proj
-        .try_into_col()
-        .map_err(|e| context!("single column projection required", e))?;
-    let outer_col = query_info
-        .where_in_expr
-        .try_into_col()
-        .map_err(|e| context!("column comparison required", e))?;
-
-    // If subquery is correlated, grab necessary information
-    let mut subqry_cols = vec![];
-    let mut outer_cols = vec![];
-    let mut join_filters = None;
-    let mut other_subqry_exprs = vec![];
-    if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() {
-        // split into filters
-        let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate);
-        verify_not_disjunction(&subqry_filter_exprs)?;
-
-        // Grab column names to join on
-        let (col_exprs, other_exprs) =
-            find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())
-                .map_err(|e| context!("column correlation not found", e))?;
-        if !col_exprs.is_empty() {
-            // it's correlated
-            subqry_input = subqry_filter.input.clone();
-            (outer_cols, subqry_cols, join_filters) =
-                exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)
-                    .map_err(|e| context!("column correlation not found", e))?;
-            other_subqry_exprs = other_exprs;
-        }
-    }
 
-    let (subqry_cols, outer_cols) =
-        merge_cols((&[subquery_col], &subqry_cols), (&[outer_col], &outer_cols));
-
-    // build subquery side of join - the thing the subquery was querying
-    let subqry_alias = alias.next("__correlated_sq");
-    let mut subqry_plan = LogicalPlanBuilder::from((*subqry_input).clone());
-    if let Some(expr) = conjunction(other_subqry_exprs) {
-        // if the subquery had additional expressions, restore them
-        subqry_plan = subqry_plan.filter(expr)?
+    // extract join filters
+    let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?;
+
+    // in_predicate may also be included in the join filters; remove it from the join filters.
+    let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone());
+    let join_filters = remove_duplicated_filter(join_filters, in_predicate);
+
+    // replace qualified name with subquery alias.
+    let subquery_alias = alias.next("__correlated_sq");
+    let input_schema = subquery_input.schema();
+    let mut subquery_cols: BTreeSet<Column> =
+        join_filters
+            .iter()
+            .try_fold(BTreeSet::new(), |mut cols, expr| {
+                let using_cols: Vec<Column> = expr
+                    .to_columns()?
+                    .into_iter()
+                    .filter(|col| input_schema.field_from_column(col).is_ok())
+                    .collect::<_>();
+
+                cols.extend(using_cols);
+                Result::Ok(cols)
+            })?;
+    let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| {
+        replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some)
+    })?;
+
+    // add projection
+    if let Expr::Column(col) = subquery_expr {
+        subquery_cols.remove(col);
     }
-    let projection = alias_cols(&subqry_cols);
-    let subqry_plan = subqry_plan
-        .project(projection)?
-        .alias(&subqry_alias)?
+    let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone()));
+    let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone());
+    let projection_exprs: Vec<Expr> = [first_expr]
+        .into_iter()
+        .chain(subquery_cols.into_iter().map(Expr::Column))
+        .collect();
+
+    let right = LogicalPlanBuilder::from(subquery_input)
+        .project(projection_exprs)?
+        .alias(&subquery_alias)?
         .build()?;
-    debug!("subquery plan:\n{}", subqry_plan.display_indent());
-
-    // qualify the join columns for outside the subquery
-    let subqry_cols = swap_table(&subqry_alias, &subqry_cols);
-    let join_keys = (outer_cols, subqry_cols);
 
     // join our sub query into the main plan
     let join_type = match query_info.negated {
         true => JoinType::LeftAnti,
         false => JoinType::LeftSemi,
     };
-    let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join(
-        subqry_plan,
+    let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name);
+    let in_predicate = Expr::eq(
+        query_info.where_in_expr.clone(),
+        Expr::Column(right_join_col),
+    );
+    let join_filter = join_filter
+        .map(|filter| in_predicate.clone().and(filter))
+        .unwrap_or_else(|| in_predicate);
+
+    let mut new_plan = LogicalPlanBuilder::from(left.clone()).join(
+        right,
         join_type,
-        join_keys,
-        join_filters,
+        (Vec::<Column>::new(), Vec::<Column>::new()),
+        Some(join_filter),
     )?;
+
     if let Some(expr) = conjunction(outer_other_exprs.to_vec()) {
         new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them
     }
@@ -205,6 +223,72 @@ fn optimize_where_in(
     Ok(new_plan)
 }
 
+fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>, LogicalPlan)> {
+    if let LogicalPlan::Filter(plan_filter) = maybe_filter {
+        let input_schema = plan_filter.input.schema();
+        let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
+
+        let mut join_filters: Vec<Expr> = vec![];
+        let mut subquery_filters: Vec<Expr> = vec![];
+        for expr in subquery_filter_exprs {
+            let cols = expr.to_columns()?;
+            if check_all_column_from_schema(&cols, input_schema.clone()) {
+                subquery_filters.push(expr.clone());
+            } else {
+                join_filters.push(expr.clone())
+            }
+        }
+
+        // if the subquery still has filter expressions, restore them.
+        let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());
+        if let Some(expr) = conjunction(subquery_filters) {
+            plan = plan.filter(expr)?
+        }
+
+        Ok((join_filters, plan.build()?))
+    } else {
+        Ok((vec![], maybe_filter.clone()))
+    }
+}
+
+/// Drops every filter that duplicates `in_predicate`, comparing operands in either order.
+fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> Vec<Expr> {
+    filters
+        .into_iter()
+        .filter(|filter| {
+            if filter == &in_predicate {
+                return false;
+            }
+            // Also drop the mirrored form (`b = a`); the operator must match in both orders.
+            !match (filter, &in_predicate) {
+                (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
+                    (a_expr.op == b_expr.op)
+                        && ((a_expr.left == b_expr.left && a_expr.right == b_expr.right)
+                            || (a_expr.left == b_expr.right && a_expr.right == b_expr.left))
+                }
+                _ => false,
+            }
+        })
+        .collect::<Vec<_>>()
+}
+
+fn replace_qualified_name(
+    expr: Expr,
+    cols: &BTreeSet<Column>,
+    subquery_alias: &str,
+) -> Result<Expr> {
+    let alias_cols: Vec<Column> = cols
+        .iter()
+        .map(|col| {
+            Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name))
+        })
+        .collect();
+    let replace_map: HashMap<&Column, &Column> =
+        cols.iter().zip(alias_cols.iter()).collect();
+
+    replace_col(expr, &replace_map)
+}
+
 struct SubqueryInfo {
     query: Subquery,
     where_in_expr: Expr,
@@ -263,8 +347,8 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.b [b:UInt32]\
-        \n  LeftSemi Join: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
-        \n    LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n  LeftSemi Join:  Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n    LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
         \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n      SubqueryAlias: __correlated_sq_1 [c:UInt32]\
         \n        Projection: sq_1.c AS c [c:UInt32]\
@@ -272,7 +356,6 @@ mod tests {
         \n    SubqueryAlias: __correlated_sq_2 [c:UInt32]\
         \n      Projection: sq_2.c AS c [c:UInt32]\
         \n        TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]";
-
         assert_optimized_plan_equal(&plan, expected)
     }
 
@@ -293,7 +376,7 @@ mod tests {
 
         let expected = "Projection: test.b [b:UInt32]\
         \n  Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\
-        \n    LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n    LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
         \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n      SubqueryAlias: __correlated_sq_1 [c:UInt32]\
         \n        Projection: sq.c AS c [c:UInt32]\
@@ -347,7 +430,7 @@ mod tests {
         \n    Subquery: [c:UInt32]\
         \n      Projection: sq1.c [c:UInt32]\
         \n        TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
-        \n    LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n    LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
         \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n      SubqueryAlias: __correlated_sq_1 [c:UInt32]\
         \n        Projection: sq2.c AS c [c:UInt32]\
@@ -372,11 +455,11 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.b [b:UInt32]\
-        \n  LeftSemi Join: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n  LeftSemi Join:  Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
         \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n    SubqueryAlias: __correlated_sq_1 [a:UInt32]\
         \n      Projection: sq.a AS a [a:UInt32]\
-        \n        LeftSemi Join: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n        LeftSemi Join:  Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
         \n          TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
         \n          SubqueryAlias: __correlated_sq_2 [c:UInt32]\
         \n            Projection: sq_nested.c AS c [c:UInt32]\
@@ -401,14 +484,14 @@ mod tests {
             .project(vec![col("b")])?
             .build()?;
 
-        let expected =  "Projection: wrapped.b [b:UInt32]\
+        let expected = "Projection: wrapped.b [b:UInt32]\
         \n  Filter: wrapped.b < UInt32(30) OR wrapped.c IN (<subquery>) [b:UInt32, c:UInt32]\
         \n    Subquery: [c:UInt32]\
         \n      Projection: sq_outer.c [c:UInt32]\
         \n        TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\
         \n    SubqueryAlias: wrapped [b:UInt32, c:UInt32]\
         \n      Projection: test.b, test.c [b:UInt32, c:UInt32]\
-        \n        LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n        LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
         \n          TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n          SubqueryAlias: __correlated_sq_1 [c:UInt32]\
         \n            Projection: sq_inner.c AS c [c:UInt32]\
@@ -443,14 +526,16 @@ mod tests {
         debug!("plan to optimize:\n{}", plan.display_indent());
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  LeftSemi Join: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\
-        \n    LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n    LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
         \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n      SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n        Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
         \n          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
-        \n    SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+        \n    SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\
+        \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
         \n        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
         assert_optimized_plan_eq_display_indent(
             Arc::new(DecorrelateWhereIn::new()),
             &plan,
@@ -486,11 +571,11 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
         \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
-        \n        LeftSemi Join: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+        \n        LeftSemi Join:  Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
         \n          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
         \n          SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\
         \n            Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\
@@ -524,7 +609,7 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
         \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -554,14 +639,12 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        // Query will fail, but we can still transform the plan
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
         \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
-        \n        Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
-        \n          TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+        \n        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
 
         assert_optimized_plan_eq_display_indent(
             Arc::new(DecorrelateWhereIn::new()),
@@ -587,7 +670,7 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
         \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -618,7 +701,7 @@ mod tests {
             .build()?;
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
-        \n  LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
         \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -647,11 +730,17 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        // can't optimize on arbitrary expressions (yet)
-        assert_optimizer_err(
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+        \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+        \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+        \n        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_optimized_plan_eq_display_indent(
             Arc::new(DecorrelateWhereIn::new()),
             &plan,
-            "column correlation not found",
+            expected,
         );
         Ok(())
     }
@@ -675,11 +764,19 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        assert_optimizer_err(
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\
+        \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+        \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\
+        \n      Projection: orders.o_custkey AS o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
+        \n        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_optimized_plan_eq_display_indent(
             Arc::new(DecorrelateWhereIn::new()),
             &plan,
-            "Optimizing disjunctions not supported!",
+            expected,
         );
+
         Ok(())
     }
 
@@ -721,11 +818,17 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        // TODO: support join on expression
-        assert_optimizer_err(
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+        \n    SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+        \n      Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+        \n        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_optimized_plan_eq_display_indent(
             Arc::new(DecorrelateWhereIn::new()),
             &plan,
-            "column comparison required",
+            expected,
         );
         Ok(())
     }
@@ -745,11 +848,17 @@ mod tests {
             .project(vec![col("customer.c_custkey")])?
             .build()?;
 
-        // TODO: support join on expressions?
-        assert_optimizer_err(
+        let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+        \n  LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n    TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+        \n    SubqueryAlias: __correlated_sq_1 [o_custkey + Int32(1):Int64, o_custkey:Int64]\
+        \n      Projection: orders.o_custkey + Int32(1) AS o_custkey + Int32(1), orders.o_custkey [o_custkey + Int32(1):Int64, o_custkey:Int64]\
+        \n        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+        assert_optimized_plan_eq_display_indent(
             Arc::new(DecorrelateWhereIn::new()),
             &plan,
-            "single column projection required",
+            expected,
         );
         Ok(())
     }
@@ -800,7 +909,7 @@ mod tests {
 
         let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
         \n  Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
-        \n    LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+        \n    LeftSemi Join:  Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
         \n      TableScan: customer [c_custkey:Int64, c_name:Utf8]\
         \n      SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
         \n        Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -865,10 +974,10 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.b [b:UInt32]\
-        \n  LeftSemi Join: test.c = __correlated_sq_1.c, test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
         \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n    SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\
-        \n      Projection: sq.c AS c, sq.a AS a [c:UInt32, a:UInt32]\
+        \n      Projection: sq.c AS c, sq.a [c:UInt32, a:UInt32]\
         \n        TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_eq_display_indent(
@@ -889,7 +998,7 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.b [b:UInt32]\
-        \n  LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
         \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n    SubqueryAlias: __correlated_sq_1 [c:UInt32]\
         \n      Projection: sq.c AS c [c:UInt32]\
@@ -913,7 +1022,7 @@ mod tests {
             .build()?;
 
         let expected = "Projection: test.b [b:UInt32]\
-        \n  LeftAnti Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+        \n  LeftAnti Join:  Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
         \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
         \n    SubqueryAlias: __correlated_sq_1 [c:UInt32]\
         \n      Projection: sq.c AS c [c:UInt32]\
@@ -926,4 +1035,153 @@ mod tests {
         );
         Ok(())
     }
+
+    #[test]
+    fn in_subquery_both_side_expr() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n    SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32]\
+        \n      Projection: sq.c * UInt32(2) AS c * UInt32(2) [c * UInt32(2):UInt32]\
+        \n        TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn in_subquery_join_filter_and_inner_filter() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .filter(
+                col("test.a")
+                    .eq(col("sq.a"))
+                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
+            )?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n    SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\
+        \n      Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a [c * UInt32(2):UInt32, a:UInt32]\
+        \n        Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\
+        \n          TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn in_subquery_muti_project_subquery_cols() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan = test_table_scan_with_name("sq")?;
+
+        let subquery = LogicalPlanBuilder::from(subquery_scan)
+            .filter(
+                col("test.a")
+                    .add(col("test.b"))
+                    .eq(col("sq.a").add(col("sq.b")))
+                    .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
+            )?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n    SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\
+        \n      Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a, sq.b [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\
+        \n        Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\
+        \n          TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn two_in_subquery_with_outer_filter() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let subquery_scan1 = test_table_scan_with_name("sq1")?;
+        let subquery_scan2 = test_table_scan_with_name("sq2")?;
+
+        let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
+            .filter(col("test.a").gt(col("sq1.a")))?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
+            .filter(col("test.a").gt(col("sq2.a")))?
+            .project(vec![col("c") * lit(2u32)])?
+            .build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(
+                in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and(
+                    in_subquery(col("c") * lit(2u32), Arc::new(subquery2))
+                        .and(col("test.c").gt(lit(1u32))),
+                ),
+            )?
+            .project(vec![col("test.b")])?
+            .build()?;
+
+        // The filter `test.c > UInt32(1)` appears twice in the optimized plan.
+        // See issue: https://github.com/apache/arrow-datafusion/issues/4914
+        let expected = "Projection: test.b [b:UInt32]\
+        \n  Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+        \n    LeftSemi Join:  Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n      Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+        \n        LeftSemi Join:  Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+        \n          TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+        \n          SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\
+        \n            Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\
+        \n              TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
+        \n      SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\
+        \n        Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\
+        \n          TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(DecorrelateWhereIn::new()),
+            &plan,
+            expected,
+        );
+        Ok(())
+    }
 }