You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/15 12:23:50 UTC
[arrow-datafusion] branch master updated: Support non-tuple expression for in-subquery to join (#4826)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new e2daee92c Support non-tuple expression for in-subquery to join (#4826)
e2daee92c is described below
commit e2daee92c5b1c24481ac5903c82aa5bbed1395ef
Author: ygf11 <ya...@gmail.com>
AuthorDate: Sun Jan 15 20:23:45 2023 +0800
Support non-tuple expression for in-subquery to join (#4826)
* Support non-tuple expression for in-subquery to join
* add tests
* add comment and fix cargo fmt
* fix comment
* clean unused comment
* Update datafusion/optimizer/src/decorrelate_where_in.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Update datafusion/optimizer/src/decorrelate_where_in.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Update datafusion/optimizer/src/decorrelate_where_in.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* fix comment
* fix cargo fmt
* add tests
* fix cargo fmt
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/core/tests/sql/joins.rs | 275 ++++++++++++++
datafusion/core/tests/sql/subqueries.rs | 13 +-
datafusion/expr/src/utils.rs | 5 +-
datafusion/optimizer/src/decorrelate_where_in.rs | 456 ++++++++++++++++++-----
4 files changed, 643 insertions(+), 106 deletions(-)
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index db5c706d3..c20c66e10 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2868,3 +2868,278 @@ async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()>
Ok(())
}
+
+/// `IN` subquery where both the outer operand (`t1.t1_id + 12`) and the subquery projection
+/// (`t2.t2_id + 1`) are expressions rather than bare columns.
+#[tokio::test]
+async fn subquery_to_join_with_both_side_expr() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in (select t2.t2_id + 1 from t2)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ // assert query results
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "| 44 | d | 4 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+/// Correlated (`t1.t1_int <= t2.t2_int`) and uncorrelated (`t2.t2_int > 0`) predicates in the
+/// same subquery: the expected plan moves the correlated one into the join filter and keeps the
+/// uncorrelated one as a `Filter` inside the subquery.
+#[tokio::test]
+async fn subquery_to_join_with_multi_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t2.t2_int > 0)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ // assert query results
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+/// Two correlated predicates (`t1.t1_int <= t2.t2_int`, `t1.t1_name != t2.t2_name`) plus an
+/// uncorrelated one (`t2.t2_int > 0`): the subquery projection carries the `IN` expression and
+/// the two columns the join filter needs, i.e. three projection expressions.
+#[tokio::test]
+async fn three_projection_exprs_subquery_to_join() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+ " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ // assert query results
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+/// `IN` subquery whose only predicate (`t1.t1_int > 0`) references just the outer table.
+#[tokio::test]
+async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int > 0)";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ // `t1.t1_int > UInt32(0)` references only `t1`, so ideally the filter push down rule would
+ // push it into the left input; in the expected plan below it still remains in the join filter.
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ // assert query results
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "| 44 | d | 4 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+/// `IN` subquery with two correlated predicates, combined with an extra outer conjunct
+/// (`t1.t1_id > 0`) that the expected plan applies as a `Filter` on `t1` below the join.
+#[tokio::test]
+async fn in_subquery_to_join_with_outer_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
+ (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name) and t1.t1_id > 0";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ // assert query results
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 11 | a | 1 |",
+ "| 33 | c | 3 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
+
+/// Two `IN` subqueries in one `WHERE` clause: the expected plan stacks two `LeftSemi` joins
+/// (`__correlated_sq_1`, `__correlated_sq_2`) over `t1`, with `t1.t1_id > 0` applied below them.
+#[tokio::test]
+async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", false)?;
+
+ let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
+ (select t2.t2_id + 1 from t2)
+ and t1.t1_int in(select t2.t2_int + 1 from t2)
+ and t1.t1_id > 0";
+
+ // assert logical plan
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan().unwrap();
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_int AS Int64) = __correlated_sq_2.CAST(t2_int AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_2 [CAST(t2_int AS Int64) + Int64(1):Int64;N]",
+ " Projection: CAST(t2.t2_int AS Int64) + Int64(1) AS CAST(t2_int AS Int64) + Int64(1) [CAST(t2_int AS Int64) + Int64(1):Int64;N]",
+ " TableScan: t2 projection=[t2_int] [t2_int:UInt32;N]",
+ ];
+
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ // assert query results
+ let expected = vec![
+ "+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int |",
+ "+-------+---------+--------+",
+ "| 44 | d | 4 |",
+ "+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+
+ Ok(())
+}
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index 2627a2db0..6928e98b7 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -94,12 +94,13 @@ where o_orderstatus in (
let dataframe = ctx.sql(sql).await.unwrap();
let plan = dataframe.into_optimized_plan().unwrap();
let actual = format!("{}", plan.display_indent());
- let expected = r#"Projection: orders.o_orderkey
- LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey
- TableScan: orders projection=[o_orderkey, o_orderstatus]
- SubqueryAlias: __correlated_sq_1
- Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey
- TableScan: lineitem projection=[l_orderkey, l_linestatus]"#;
+
+ let expected = "Projection: orders.o_orderkey\
+ \n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\
+ \n TableScan: orders projection=[o_orderkey, o_orderstatus]\
+ \n SubqueryAlias: __correlated_sq_1\
+ \n Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\
+ \n TableScan: lineitem projection=[l_orderkey, l_linestatus]";
assert_eq!(actual, expected);
// assert data
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 682d13321..e84ba0b6f 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -965,7 +965,10 @@ pub fn can_hash(data_type: &DataType) -> bool {
}
/// Check whether all columns are from the schema.
-fn check_all_column_from_schema(columns: &HashSet<Column>, schema: DFSchemaRef) -> bool {
+pub fn check_all_column_from_schema(
+ columns: &HashSet<Column>,
+ schema: DFSchemaRef,
+) -> bool {
columns
.iter()
.all(|column| schema.index_of_column(column).is_ok())
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs
index 1aa976ce8..13e3acf78 100644
--- a/datafusion/optimizer/src/decorrelate_where_in.rs
+++ b/datafusion/optimizer/src/decorrelate_where_in.rs
@@ -17,15 +17,15 @@
use crate::alias::AliasGenerator;
use crate::optimizer::ApplyOrder;
-use crate::utils::{
- alias_cols, conjunction, exprs_to_join_cols, find_join_exprs, merge_cols,
- only_or_err, split_conjunction, swap_table, verify_not_disjunction,
-};
+use crate::utils::{conjunction, only_or_err, split_conjunction};
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{context, Result};
+use datafusion_common::{context, Column, Result};
+use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col};
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
+use datafusion_expr::utils::check_all_column_from_schema;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use log::debug;
+use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
#[derive(Default)]
@@ -96,6 +96,7 @@ impl OptimizerRule for DecorrelateWhereIn {
return Ok(None);
}
+ // iterate through all exists clauses in predicate, turning each into a join
// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = filter.input.as_ref().clone();
for subquery in subqueries {
@@ -121,81 +122,98 @@ impl OptimizerRule for DecorrelateWhereIn {
}
}
+/// Rewrites a `WHERE <expr> [NOT] IN (<subquery>)` predicate into a left-semi join
+/// (a left-anti join when the `IN` is negated) between `left` and the subquery.
+/// If the subquery is correlated, its correlated predicates are extracted and added to
+/// the join filter together with the `IN` equality; `outer_other_exprs` (the remaining
+/// conjuncts of the outer filter) are re-applied as a filter on top of the join.
+///
+/// For example, given a query like:
+/// `select t1.a, t1.b from t1 where t1.a in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)`
+///
+/// The optimized plan will be:
+///
+/// ```text
+/// Projection: t1.a, t1.b
+///   LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c
+///     TableScan: t1
+///     SubqueryAlias: __correlated_sq_1
+///       Projection: t2.a AS a, t2.b, t2.c
+///         TableScan: t2
+/// ```
fn optimize_where_in(
query_info: &SubqueryInfo,
- outer_input: &LogicalPlan,
+ left: &LogicalPlan,
outer_other_exprs: &[Expr],
alias: &AliasGenerator,
) -> Result<LogicalPlan> {
- let proj = Projection::try_from_plan(&query_info.query.subquery)
+ let projection = Projection::try_from_plan(&query_info.query.subquery)
.map_err(|e| context!("a projection is required", e))?;
- let mut subqry_input = proj.input.clone();
- let proj = only_or_err(proj.expr.as_slice())
+ let subquery_input = projection.input.clone();
+ let subquery_expr = only_or_err(projection.expr.as_slice())
.map_err(|e| context!("single expression projection required", e))?;
- let subquery_col = proj
- .try_into_col()
- .map_err(|e| context!("single column projection required", e))?;
- let outer_col = query_info
- .where_in_expr
- .try_into_col()
- .map_err(|e| context!("column comparison required", e))?;
-
- // If subquery is correlated, grab necessary information
- let mut subqry_cols = vec![];
- let mut outer_cols = vec![];
- let mut join_filters = None;
- let mut other_subqry_exprs = vec![];
- if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() {
- // split into filters
- let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate);
- verify_not_disjunction(&subqry_filter_exprs)?;
-
- // Grab column names to join on
- let (col_exprs, other_exprs) =
- find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())
- .map_err(|e| context!("column correlation not found", e))?;
- if !col_exprs.is_empty() {
- // it's correlated
- subqry_input = subqry_filter.input.clone();
- (outer_cols, subqry_cols, join_filters) =
- exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)
- .map_err(|e| context!("column correlation not found", e))?;
- other_subqry_exprs = other_exprs;
- }
- }
- let (subqry_cols, outer_cols) =
- merge_cols((&[subquery_col], &subqry_cols), (&[outer_col], &outer_cols));
-
- // build subquery side of join - the thing the subquery was querying
- let subqry_alias = alias.next("__correlated_sq");
- let mut subqry_plan = LogicalPlanBuilder::from((*subqry_input).clone());
- if let Some(expr) = conjunction(other_subqry_exprs) {
- // if the subquery had additional expressions, restore them
- subqry_plan = subqry_plan.filter(expr)?
+ // split the subquery's filter into join filters (referencing outer columns) and local filters
+ let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?;
+
+ // the in_predicate may also appear among the join filters; remove it to avoid a redundant check.
+ let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone());
+ let join_filters = remove_duplicated_filter(join_filters, in_predicate);
+
+ // collect the subquery columns used by the join filters and re-qualify them with the subquery alias.
+ let subquery_alias = alias.next("__correlated_sq");
+ let input_schema = subquery_input.schema();
+ let mut subquery_cols: BTreeSet<Column> =
+ join_filters
+ .iter()
+ .try_fold(BTreeSet::new(), |mut cols, expr| {
+ let using_cols: Vec<Column> = expr
+ .to_columns()?
+ .into_iter()
+ .filter(|col| input_schema.field_from_column(col).is_ok())
+ .collect();
+
+ cols.extend(using_cols);
+ Result::Ok(cols)
+ })?;
+ let join_filter = conjunction(join_filters)
+ .map(|filter| replace_qualified_name(filter, &subquery_cols, &subquery_alias))
+ .transpose()?;
+
+ // build the subquery projection: the IN expression first, then the columns the join filter needs.
+ // A bare-column IN expression is already projected first, so don't project it twice.
+ if let Expr::Column(col) = subquery_expr {
+ subquery_cols.remove(col);
}
- let projection = alias_cols(&subqry_cols);
- let subqry_plan = subqry_plan
- .project(projection)?
- .alias(&subqry_alias)?
+ let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone()));
+ let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone());
+ let projection_exprs: Vec<Expr> = [first_expr]
+ .into_iter()
+ .chain(subquery_cols.into_iter().map(Expr::Column))
+ .collect();
+
+ let right = LogicalPlanBuilder::from(subquery_input)
+ .project(projection_exprs)?
+ .alias(&subquery_alias)?
.build()?;
- debug!("subquery plan:\n{}", subqry_plan.display_indent());
-
- // qualify the join columns for outside the subquery
- let subqry_cols = swap_table(&subqry_alias, &subqry_cols);
- let join_keys = (outer_cols, subqry_cols);
// join our sub query into the main plan
let join_type = match query_info.negated {
true => JoinType::LeftAnti,
false => JoinType::LeftSemi,
};
- let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join(
- subqry_plan,
+ let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name);
+ let in_predicate = Expr::eq(
+ query_info.where_in_expr.clone(),
+ Expr::Column(right_join_col),
+ );
+ let join_filter = match join_filter {
+ Some(filter) => in_predicate.and(filter),
+ None => in_predicate,
+ };
+ let mut new_plan = LogicalPlanBuilder::from(left.clone()).join(
+ right,
join_type,
- join_keys,
- join_filters,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(join_filter),
)?;
+
if let Some(expr) = conjunction(outer_other_exprs.to_vec()) {
new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them
}
@@ -205,6 +223,72 @@ fn optimize_where_in(
Ok(new_plan)
}
+/// Separates the predicates of a subquery that reference outer columns from its local ones.
+///
+/// When `maybe_filter` is a `Filter`, each conjunct of its predicate whose columns all resolve
+/// against the filter's input schema stays in the subquery (re-applied as a `Filter` over that
+/// input); every other conjunct is returned as a join filter. Any other plan is returned
+/// unchanged together with an empty list of join filters.
+fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>, LogicalPlan)> {
+    let plan_filter = match maybe_filter {
+        LogicalPlan::Filter(plan_filter) => plan_filter,
+        _ => return Ok((vec![], maybe_filter.clone())),
+    };
+
+    let input_schema = plan_filter.input.schema();
+    let mut join_filters: Vec<Expr> = vec![];
+    let mut subquery_filters: Vec<Expr> = vec![];
+    for expr in split_conjunction(&plan_filter.predicate) {
+        let cols = expr.to_columns()?;
+        let bucket = if check_all_column_from_schema(&cols, input_schema.clone()) {
+            &mut subquery_filters
+        } else {
+            &mut join_filters
+        };
+        bucket.push(expr.clone());
+    }
+
+    // restore the predicates that only touch the subquery's own input
+    let builder = LogicalPlanBuilder::from((*plan_filter.input).clone());
+    let rebuilt = match conjunction(subquery_filters) {
+        Some(predicate) => builder.filter(predicate)?,
+        None => builder,
+    };
+
+    Ok((join_filters, rebuilt.build()?))
+}
+
+/// Returns `filters` without the predicates that are equivalent to `in_predicate`: those equal
+/// to it, and binary expressions with the same operator whose two operands are swapped.
+///
+/// Treating swapped operands as equivalent is only valid for commutative operators. The caller
+/// builds `in_predicate` with `Expr::eq`, and the operator check restricts the match to that.
+fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> Vec<Expr> {
+    filters
+        .into_iter()
+        .filter(|filter| {
+            if filter == &in_predicate {
+                return false;
+            }
+
+            // Ignore the operand order, but always require the same operator: both operand
+            // orders are checked under the operator test, so e.g. `t2.b < t1.a` is kept when
+            // `in_predicate` is `t1.a = t2.b`.
+            !match (filter, &in_predicate) {
+                (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
+                    a_expr.op == b_expr.op
+                        && ((a_expr.left == b_expr.left && a_expr.right == b_expr.right)
+                            || (a_expr.left == b_expr.right && a_expr.right == b_expr.left))
+                }
+                _ => false,
+            }
+        })
+        .collect()
+}
+
+/// Rewrites every column of `cols` referenced by `expr` into a column with the same name
+/// qualified by `subquery_alias` (e.g. `t2.b` becomes `__correlated_sq_1.b`).
+fn replace_qualified_name(
+    expr: Expr,
+    cols: &BTreeSet<Column>,
+    subquery_alias: &str,
+) -> Result<Expr> {
+    // Build the aliased column directly instead of formatting "alias.name" and re-parsing it
+    // with `Column::from_qualified_name`: that parse may not round-trip a column name that is
+    // not a simple identifier (e.g. one containing `.` or parentheses).
+    let alias_cols: Vec<Column> = cols
+        .iter()
+        .map(|col| Column::new(Some(subquery_alias.to_owned()), col.name.clone()))
+        .collect();
+    let replace_map: HashMap<&Column, &Column> =
+        cols.iter().zip(alias_cols.iter()).collect();
+
+    replace_col(expr, &replace_map)
+}
+
struct SubqueryInfo {
query: Subquery,
where_in_expr: Expr,
@@ -263,8 +347,8 @@ mod tests {
.build()?;
let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq_1.c AS c [c:UInt32]\
@@ -272,7 +356,6 @@ mod tests {
\n SubqueryAlias: __correlated_sq_2 [c:UInt32]\
\n Projection: sq_2.c AS c [c:UInt32]\
\n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]";
-
assert_optimized_plan_equal(&plan, expected)
}
@@ -293,7 +376,7 @@ mod tests {
let expected = "Projection: test.b [b:UInt32]\
\n Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq.c AS c [c:UInt32]\
@@ -347,7 +430,7 @@ mod tests {
\n Subquery: [c:UInt32]\
\n Projection: sq1.c [c:UInt32]\
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq2.c AS c [c:UInt32]\
@@ -372,11 +455,11 @@ mod tests {
.build()?;
let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
\n Projection: sq.a AS a [a:UInt32]\
- \n LeftSemi Join: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_2 [c:UInt32]\
\n Projection: sq_nested.c AS c [c:UInt32]\
@@ -401,14 +484,14 @@ mod tests {
.project(vec![col("b")])?
.build()?;
- let expected = "Projection: wrapped.b [b:UInt32]\
+ let expected = "Projection: wrapped.b [b:UInt32]\
\n Filter: wrapped.b < UInt32(30) OR wrapped.c IN (<subquery>) [b:UInt32, c:UInt32]\
\n Subquery: [c:UInt32]\
\n Projection: sq_outer.c [c:UInt32]\
\n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\
\n Projection: test.b, test.c [b:UInt32, c:UInt32]\
- \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq_inner.c AS c [c:UInt32]\
@@ -443,14 +526,16 @@ mod tests {
debug!("plan to optimize:\n{}", plan.display_indent());
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereIn::new()),
&plan,
@@ -486,11 +571,11 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
- \n LeftSemi Join: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\
\n Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\
@@ -524,7 +609,7 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -554,14 +639,12 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- // Query will fail, but we can still transform the plan
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
- \n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereIn::new()),
@@ -587,7 +670,7 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -618,7 +701,7 @@ mod tests {
.build()?;
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -647,11 +730,17 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- // can't optimize on arbitrary expressions (yet)
- assert_optimizer_err(
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereIn::new()),
&plan,
- "column correlation not found",
+ expected,
);
Ok(())
}
@@ -675,11 +764,19 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- assert_optimizer_err(
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\
+ \n Projection: orders.o_custkey AS o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereIn::new()),
&plan,
- "Optimizing disjunctions not supported!",
+ expected,
);
+
Ok(())
}
@@ -721,11 +818,17 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- // TODO: support join on expression
- assert_optimizer_err(
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereIn::new()),
&plan,
- "column comparison required",
+ expected,
);
Ok(())
}
@@ -745,11 +848,17 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;
- // TODO: support join on expressions?
- assert_optimizer_err(
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey + Int32(1):Int64, o_custkey:Int64]\
+ \n Projection: orders.o_custkey + Int32(1) AS o_custkey + Int32(1), orders.o_custkey [o_custkey + Int32(1):Int64, o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_eq_display_indent(
Arc::new(DecorrelateWhereIn::new()),
&plan,
- "single column projection required",
+ expected,
);
Ok(())
}
@@ -800,7 +909,7 @@ mod tests {
let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
\n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
- \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
@@ -865,10 +974,10 @@ mod tests {
.build()?;
let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: test.c = __correlated_sq_1.c, test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\
- \n Projection: sq.c AS c, sq.a AS a [c:UInt32, a:UInt32]\
+ \n Projection: sq.c AS c, sq.a [c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
@@ -889,7 +998,7 @@ mod tests {
.build()?;
let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq.c AS c [c:UInt32]\
@@ -913,7 +1022,7 @@ mod tests {
.build()?;
let expected = "Projection: test.b [b:UInt32]\
- \n LeftAnti Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
\n Projection: sq.c AS c [c:UInt32]\
@@ -926,4 +1035,153 @@ mod tests {
);
Ok(())
}
+
+ #[test]
+ fn in_subquery_both_side_expr() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .project(vec![col("c") * lit(2u32)])?
+ .build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32]\
+ \n Projection: sq.c * UInt32(2) AS c * UInt32(2) [c * UInt32(2):UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn in_subquery_join_filter_and_inner_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(
+ col("test.a")
+ .eq(col("sq.a"))
+ .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
+ )?
+ .project(vec![col("c") * lit(2u32)])?
+ .build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\
+ \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a [c * UInt32(2):UInt32, a:UInt32]\
+ \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn in_subquery_muti_project_subquery_cols() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(
+ col("test.a")
+ .add(col("test.b"))
+ .eq(col("sq.a").add(col("sq.b")))
+ .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))),
+ )?
+ .project(vec![col("c") * lit(2u32)])?
+ .build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\
+ \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a, sq.b [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\
+ \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn two_in_subquery_with_outer_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan1 = test_table_scan_with_name("sq1")?;
+ let subquery_scan2 = test_table_scan_with_name("sq2")?;
+
+ let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
+ .filter(col("test.a").gt(col("sq1.a")))?
+ .project(vec![col("c") * lit(2u32)])?
+ .build()?;
+
+ let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
+ .filter(col("test.a").gt(col("sq2.a")))?
+ .project(vec![col("c") * lit(2u32)])?
+ .build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(
+ in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and(
+ in_subquery(col("c") * lit(2u32), Arc::new(subquery2))
+ .and(col("test.c").gt(lit(1u32))),
+ ),
+ )?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ // Filter: test.c > UInt32(1) happen twice.
+ // issue: https://github.com/apache/arrow-datafusion/issues/4914
+ let expected = "Projection: test.b [b:UInt32]\
+ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\
+ \n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\
+ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\
+ \n Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\
+ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(DecorrelateWhereIn::new()),
+ &plan,
+ expected,
+ );
+ Ok(())
+ }
}