You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/05/15 11:07:43 UTC
[arrow-datafusion] branch main updated: Combine the two rules: DecorrelateWhereExists and DecorrelateWhereIn (#6271)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ca7f760cc9 Combine the two rules: DecorrelateWhereExists and DecorrelateWhereIn (#6271)
ca7f760cc9 is described below
commit ca7f760cc9421b3fc7f1abfae8da578eee5e0736
Author: Ken, Wang <mi...@gmail.com>
AuthorDate: Mon May 15 19:07:35 2023 +0800
Combine the two rules: DecorrelateWhereExists and DecorrelateWhereIn (#6271)
* combine decorrelation rules
* Fix UTs
* modify TPCH plan txt
* fix integration test
---
benchmarks/expected-plans/q16.txt | 2 +-
benchmarks/expected-plans/q18.txt | 2 +-
benchmarks/expected-plans/q20.txt | 4 +-
benchmarks/expected-plans/q21.txt | 18 +-
benchmarks/expected-plans/q22.txt | 5 +-
benchmarks/expected-plans/q4.txt | 9 +-
datafusion/core/tests/sql/joins.rs | 217 ++---
datafusion/core/tests/sql/subqueries.rs | 34 +-
...ere_in.rs => decorrelate_predicate_subquery.rs} | 979 ++++++++++++++++++---
.../optimizer/src/decorrelate_where_exists.rs | 781 ----------------
datafusion/optimizer/src/lib.rs | 3 +-
datafusion/optimizer/src/optimizer.rs | 6 +-
datafusion/optimizer/src/utils.rs | 2 +-
datafusion/optimizer/tests/integration-test.rs | 24 +-
14 files changed, 1028 insertions(+), 1058 deletions(-)
diff --git a/benchmarks/expected-plans/q16.txt b/benchmarks/expected-plans/q16.txt
index 435f8ab6a8..7cf11f7320 100644
--- a/benchmarks/expected-plans/q16.txt
+++ b/benchmarks/expected-plans/q16.txt
@@ -12,7 +12,7 @@
| | Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) |
| | TableScan: part projection=[p_partkey, p_brand, p_type, p_size] |
| | SubqueryAlias: __correlated_sq_1 |
-| | Projection: supplier.s_suppkey AS s_suppkey |
+| | Projection: supplier.s_suppkey |
| | Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") |
| | TableScan: supplier projection=[s_suppkey, s_comment] |
| physical_plan | SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] |
diff --git a/benchmarks/expected-plans/q18.txt b/benchmarks/expected-plans/q18.txt
index f403e84406..551ac99106 100644
--- a/benchmarks/expected-plans/q18.txt
+++ b/benchmarks/expected-plans/q18.txt
@@ -12,7 +12,7 @@
| | TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice, o_orderdate] |
| | TableScan: lineitem projection=[l_orderkey, l_quantity] |
| | SubqueryAlias: __correlated_sq_1 |
-| | Projection: lineitem.l_orderkey AS l_orderkey |
+| | Projection: lineitem.l_orderkey |
| | Filter: SUM(lineitem.l_quantity) > Decimal128(Some(30000),25,2) |
| | Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_quantity)]] |
| | TableScan: lineitem projection=[l_orderkey, l_quantity] |
diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt
index 03ad420c58..41f2dac583 100644
--- a/benchmarks/expected-plans/q20.txt
+++ b/benchmarks/expected-plans/q20.txt
@@ -11,12 +11,12 @@
| | Filter: nation.n_name = Utf8("CANADA") |
| | TableScan: nation projection=[n_nationkey, n_name] |
| | SubqueryAlias: __correlated_sq_1 |
-| | Projection: partsupp.ps_suppkey AS ps_suppkey |
+| | Projection: partsupp.ps_suppkey |
| | Inner Join: partsupp.ps_partkey = __scalar_sq_1.l_partkey, partsupp.ps_suppkey = __scalar_sq_1.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_1.__value |
| | LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey |
| | TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] |
| | SubqueryAlias: __correlated_sq_2 |
-| | Projection: part.p_partkey AS p_partkey |
+| | Projection: part.p_partkey |
| | Filter: part.p_name LIKE Utf8("forest%") |
| | TableScan: part projection=[p_partkey, p_name] |
| | SubqueryAlias: __scalar_sq_1 |
diff --git a/benchmarks/expected-plans/q21.txt b/benchmarks/expected-plans/q21.txt
index 47285e1c16..67605f5c44 100644
--- a/benchmarks/expected-plans/q21.txt
+++ b/benchmarks/expected-plans/q21.txt
@@ -5,8 +5,8 @@
| | Projection: supplier.s_name, COUNT(UInt8(1)) AS numwait |
| | Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(UInt8(1))]] |
| | Projection: supplier.s_name |
-| | LeftAnti Join: l1.l_orderkey = l3.l_orderkey Filter: l3.l_suppkey != l1.l_suppkey |
-| | LeftSemi Join: l1.l_orderkey = l2.l_orderkey Filter: l2.l_suppkey != l1.l_suppkey |
+| | LeftAnti Join: l1.l_orderkey = __correlated_sq_2.l_orderkey Filter: __correlated_sq_2.l_suppkey != l1.l_suppkey |
+| | LeftSemi Join: l1.l_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_suppkey != l1.l_suppkey |
| | Projection: supplier.s_name, l1.l_orderkey, l1.l_suppkey |
| | Inner Join: supplier.s_nationkey = nation.n_nationkey |
| | Projection: supplier.s_name, supplier.s_nationkey, l1.l_orderkey, l1.l_suppkey |
@@ -24,12 +24,14 @@
| | Projection: nation.n_nationkey |
| | Filter: nation.n_name = Utf8("SAUDI ARABIA") |
| | TableScan: nation projection=[n_nationkey, n_name] |
-| | SubqueryAlias: l2 |
-| | TableScan: lineitem projection=[l_orderkey, l_suppkey] |
-| | SubqueryAlias: l3 |
-| | Projection: lineitem.l_orderkey, lineitem.l_suppkey |
-| | Filter: lineitem.l_receiptdate > lineitem.l_commitdate |
-| | TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] |
+| | SubqueryAlias: __correlated_sq_1 |
+| | SubqueryAlias: l2 |
+| | TableScan: lineitem projection=[l_orderkey, l_suppkey] |
+| | SubqueryAlias: __correlated_sq_2 |
+| | SubqueryAlias: l3 |
+| | Projection: lineitem.l_orderkey, lineitem.l_suppkey |
+| | Filter: lineitem.l_receiptdate > lineitem.l_commitdate |
+| | TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] |
| physical_plan | SortPreservingMergeExec: [numwait@1 DESC,s_name@0 ASC NULLS LAST] |
| | SortExec: expr=[numwait@1 DESC,s_name@0 ASC NULLS LAST] |
| | ProjectionExec: expr=[s_name@0 as s_name, COUNT(UInt8(1))@1 as numwait] |
diff --git a/benchmarks/expected-plans/q22.txt b/benchmarks/expected-plans/q22.txt
index d05ae58f0a..a84830acea 100644
--- a/benchmarks/expected-plans/q22.txt
+++ b/benchmarks/expected-plans/q22.txt
@@ -9,10 +9,11 @@
| | Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_1.__value |
| | CrossJoin: |
| | Projection: customer.c_phone, customer.c_acctbal |
-| | LeftAnti Join: customer.c_custkey = orders.o_custkey |
+| | LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey |
| | Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) |
| | TableScan: customer projection=[c_custkey, c_phone, c_acctbal] |
-| | TableScan: orders projection=[o_custkey] |
+| | SubqueryAlias: __correlated_sq_1 |
+| | TableScan: orders projection=[o_custkey] |
| | SubqueryAlias: __scalar_sq_1 |
| | Projection: AVG(customer.c_acctbal) AS __value |
| | Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] |
diff --git a/benchmarks/expected-plans/q4.txt b/benchmarks/expected-plans/q4.txt
index c49d7f40e6..05dc5d00c0 100644
--- a/benchmarks/expected-plans/q4.txt
+++ b/benchmarks/expected-plans/q4.txt
@@ -5,13 +5,14 @@
| | Projection: orders.o_orderpriority, COUNT(UInt8(1)) AS order_count |
| | Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] |
| | Projection: orders.o_orderpriority |
-| | LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey |
+| | LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey |
| | Projection: orders.o_orderkey, orders.o_orderpriority |
| | Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674") |
| | TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority] |
-| | Projection: lineitem.l_orderkey |
-| | Filter: lineitem.l_commitdate < lineitem.l_receiptdate |
-| | TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate] |
+| | SubqueryAlias: __correlated_sq_1 |
+| | Projection: lineitem.l_orderkey |
+| | Filter: lineitem.l_commitdate < lineitem.l_receiptdate |
+| | TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate] |
| physical_plan | SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST] |
| | SortExec: expr=[o_orderpriority@0 ASC NULLS LAST] |
| | ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, COUNT(UInt8(1))@1 as order_count] |
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 54ecfe4951..57343ea95c 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -656,13 +656,13 @@ async fn reduce_left_join_1() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -693,7 +693,7 @@ async fn reduce_left_join_2() -> Result<()> {
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -714,18 +714,18 @@ async fn reduce_left_join_3() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N]",
- " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Projection: t2.t2_id [t2_id:UInt32;N]",
- " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Left Join: t3.t1_int = t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " SubqueryAlias: t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N]",
+ " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -746,12 +746,12 @@ async fn reduce_right_join_1() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: t1.t1_int IS NOT NULL [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t1.t1_int IS NOT NULL [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -776,7 +776,7 @@ async fn reduce_right_join_2() -> Result<()> {
" Inner Join: t1.t1_id = t2.t2_id Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -797,12 +797,12 @@ async fn reduce_full_join_to_right_join() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Right Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: t2.t2_name IS NOT NULL [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Right Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t2.t2_name IS NOT NULL [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -824,12 +824,12 @@ async fn reduce_full_join_to_left_join() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Left Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Left Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -850,13 +850,13 @@ async fn reduce_full_join_to_inner_join() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: t2.t2_name = Utf8(\"x\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t1.t1_name != Utf8(\"b\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t2.t2_name = Utf8(\"x\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -1011,8 +1011,7 @@ async fn left_semi_join() -> Result<()> {
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
- " ProjectionExec: expr=[t2_id@0 as t2_id]",
- " MemoryExec: partitions=1, partition_sizes=[1]",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
]
} else {
vec![
@@ -1020,8 +1019,7 @@ async fn left_semi_join() -> Result<()> {
" CoalesceBatchesExec: target_batch_size=4096",
" HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2_id\", index: 0 })]",
" MemoryExec: partitions=1, partition_sizes=[1]",
- " ProjectionExec: expr=[t2_id@0 as t2_id]",
- " MemoryExec: partitions=1, partition_sizes=[1]",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
]
};
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
@@ -1271,17 +1269,17 @@ async fn right_semi_join() -> Result<()> {
let physical_plan = dataframe.create_physical_plan().await?;
let expected = if repartition_joins {
vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
- " SortExec: expr=[t1_id@0 ASC NULLS LAST]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
- " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
- " MemoryExec: partitions=1, partition_sizes=[1]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2",
- " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
- " MemoryExec: partitions=1, partition_sizes=[1]",
+ " SortExec: expr=[t1_id@0 ASC NULLS LAST]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
+ " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2",
+ " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
]
} else {
vec![
@@ -1315,17 +1313,17 @@ async fn right_semi_join() -> Result<()> {
let physical_plan = dataframe.create_physical_plan().await?;
let expected = if repartition_joins {
vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]",
- " SortExec: expr=[t1_id@0 ASC NULLS LAST]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 0 }, op: NotEq, right: Column { name: \"t1_name\", index: 1 } }",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
- " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
- " MemoryExec: partitions=1, partition_sizes=[1]",
- " CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2",
- " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
- " MemoryExec: partitions=1, partition_sizes=[1]",
+ " SortExec: expr=[t1_id@0 ASC NULLS LAST]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 0 }, op: NotEq, right: Column { name: \"t1_name\", index: 1 } }",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
+ " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2",
+ " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
]
} else {
vec![
@@ -1532,11 +1530,11 @@ async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> {
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Inner Join: CAST(t1.t1_id AS Int64) + Int64(12) = CAST(t2.t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(12) = CAST(t2.t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1553,17 +1551,17 @@ async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;
let sql =
- "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)";
+ "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)";
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan()?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]",
- " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
- " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
- ];
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ ];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -1623,7 +1621,7 @@ async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> {
" ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS Int64)]",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" MemoryExec: partitions=1, partition_sizes=[1]",
- ]
+ ]
} else {
vec![
"ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]",
@@ -1738,7 +1736,7 @@ async fn left_side_expr_key_inner_join() -> Result<()> {
" RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" MemoryExec: partitions=1, partition_sizes=[1]",
- ]
+ ]
} else {
vec![
"ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
@@ -1793,7 +1791,7 @@ async fn right_side_expr_key_inner_join() -> Result<()> {
" ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as t2.t2_id - UInt32(11)]",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" MemoryExec: partitions=1, partition_sizes=[1]",
- ]
+ ]
} else {
vec![
"ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
@@ -2412,9 +2410,10 @@ async fn exists_subquery_to_join_expr_filter() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2454,11 +2453,12 @@ async fn exists_subquery_to_join_inner_filter() -> Result<()> {
// `t2.t2_int < 3` will be kept in the subquery filter.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Projection: t2.t2_id [t2_id:UInt32;N]",
- " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2496,10 +2496,11 @@ async fn exists_subquery_to_join_outer_filter() -> Result<()> {
// `t1.t1_int < 3` will be moved to the filter of t1.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2536,9 +2537,10 @@ async fn not_exists_subquery_to_join_expr_filter() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2574,10 +2576,11 @@ async fn exists_distinct_subquery_to_join() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2614,10 +2617,11 @@ async fn exists_distinct_subquery_to_join_with_expr() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
@@ -2654,10 +2658,11 @@ async fn exists_distinct_subquery_to_join_with_literal() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
- " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(__correlated_sq_1.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index d1e71536bd..b93e0e0c8b 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -98,7 +98,7 @@ where o_orderstatus in (
\n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\
\n TableScan: orders projection=[o_orderkey, o_orderstatus]\
\n SubqueryAlias: __correlated_sq_1\
- \n Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\
+ \n Projection: lineitem.l_linestatus, lineitem.l_orderkey\
\n TableScan: lineitem projection=[l_orderkey, l_linestatus]";
assert_eq!(actual, expected);
@@ -162,7 +162,7 @@ async fn in_subquery_with_same_table() -> Result<()> {
" LeftSemi Join: t1.t1_id = __correlated_sq_1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [t1_int:UInt32;N]",
- " Projection: t1.t1_int AS t1_int [t1_int:UInt32;N]",
+ " Projection: t1.t1_int [t1_int:UInt32;N]",
" Filter: t1.t1_id > t1.t1_int [t1_id:UInt32;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_int] [t1_id:UInt32;N, t1_int:UInt32;N]",
];
@@ -176,6 +176,36 @@ async fn in_subquery_with_same_table() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn in_subquery_nested_exist_subquery() -> Result<()> {
+ let ctx = create_join_context("t1_id", "t2_id", true)?;
+
+ let sql = "SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int))";
+ let msg = format!("Creating logical plan for '{sql}'");
+ let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+ let plan = dataframe.into_optimized_plan()?;
+
+ let expected = vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_1 [t2_id:UInt32;N]",
+ " Projection: t2.t2_id [t2_id:UInt32;N]",
+ " LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
+ " SubqueryAlias: __correlated_sq_2 [t1_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_int] [t1_int:UInt32;N]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+ );
+
+ Ok(())
+}
+
#[tokio::test]
async fn invalid_scalar_subquery() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", true)?;
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
similarity index 52%
rename from datafusion/optimizer/src/decorrelate_where_in.rs
rename to datafusion/optimizer/src/decorrelate_predicate_subquery.rs
index 0d9b472cf4..8630c60649 100644
--- a/datafusion/optimizer/src/decorrelate_where_in.rs
+++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -22,26 +22,31 @@ use crate::utils::{
replace_qualified_name, split_conjunction,
};
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{context, Column, Result};
-use datafusion_expr::expr::InSubquery;
+use datafusion_common::{context, Column, DataFusionError, Result};
+use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::unnormalize_col;
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
-use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder};
+use datafusion_expr::{
+ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Distinct, Expr, Filter,
+ LogicalPlan, LogicalPlanBuilder, Operator,
+};
use log::debug;
+use std::ops::Deref;
use std::sync::Arc;
+/// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins
#[derive(Default)]
-pub struct DecorrelateWhereIn {
+pub struct DecorrelatePredicateSubquery {
alias: AliasGenerator,
}
-impl DecorrelateWhereIn {
+impl DecorrelatePredicateSubquery {
#[allow(missing_docs)]
pub fn new() -> Self {
Self::default()
}
- /// Finds expressions that have a where in subquery (and recurses when found)
+ /// Finds expressions that have the predicate subqueries (and recurses when found)
///
/// # Arguments
///
@@ -54,7 +59,7 @@ impl DecorrelateWhereIn {
predicate: &Expr,
config: &dyn OptimizerConfig,
) -> Result<(Vec<SubqueryInfo>, Vec<Expr>)> {
- let filters = split_conjunction(predicate); // TODO: disjunctions
+ let filters = split_conjunction(predicate); // TODO: add ExistenceJoin to support disjunctions
let mut subqueries = vec![];
let mut others = vec![];
@@ -70,12 +75,19 @@ impl DecorrelateWhereIn {
.map(Arc::new)
.unwrap_or_else(|| subquery.subquery.clone());
let new_subquery = subquery.with_plan(subquery_plan);
- subqueries.push(SubqueryInfo::new(
+ subqueries.push(SubqueryInfo::new_with_in_expr(
new_subquery,
(**expr).clone(),
*negated,
));
- // TODO: if subquery doesn't get optimized, optimized children are lost
+ }
+ Expr::Exists(Exists { subquery, negated }) => {
+ let subquery_plan = self
+ .try_optimize(&subquery.subquery, config)?
+ .map(Arc::new)
+ .unwrap_or_else(|| subquery.subquery.clone());
+ let new_subquery = subquery.with_plan(subquery_plan);
+ subqueries.push(SubqueryInfo::new(new_subquery, *negated));
}
_ => others.push((*it).clone()),
}
@@ -85,7 +97,7 @@ impl DecorrelateWhereIn {
}
}
-impl OptimizerRule for DecorrelateWhereIn {
+impl OptimizerRule for DecorrelatePredicateSubquery {
fn try_optimize(
&self,
plan: &LogicalPlan,
@@ -93,7 +105,7 @@ impl OptimizerRule for DecorrelateWhereIn {
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => {
- let (subqueries, other_exprs) =
+ let (subqueries, mut other_exprs) =
self.extract_subquery_exprs(&filter.predicate, config)?;
if subqueries.is_empty() {
// regular filter, no subquery exists clause here
@@ -103,7 +115,34 @@ impl OptimizerRule for DecorrelateWhereIn {
// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = filter.input.as_ref().clone();
for subquery in subqueries {
- cur_input = optimize_where_in(&subquery, &cur_input, &self.alias)?;
+ if let Some(plan) = build_join(&subquery, &cur_input, &self.alias)? {
+ cur_input = plan;
+ } else {
+ // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
+ let sub_query_expr = match subquery {
+ SubqueryInfo {
+ query,
+ where_in_expr: Some(expr),
+ negated: false,
+ } => in_subquery(expr, query.subquery.clone()),
+ SubqueryInfo {
+ query,
+ where_in_expr: Some(expr),
+ negated: true,
+ } => not_in_subquery(expr, query.subquery.clone()),
+ SubqueryInfo {
+ query,
+ where_in_expr: None,
+ negated: false,
+ } => exists(query.subquery.clone()),
+ SubqueryInfo {
+ query,
+ where_in_expr: None,
+ negated: true,
+ } => not_exists(query.subquery.clone()),
+ };
+ other_exprs.push(sub_query_expr);
+ }
}
let expr = conjunction(other_exprs);
@@ -111,7 +150,6 @@ impl OptimizerRule for DecorrelateWhereIn {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
cur_input = LogicalPlan::Filter(new_filter);
}
-
Ok(Some(cur_input))
}
_ => Ok(None),
@@ -119,7 +157,7 @@ impl OptimizerRule for DecorrelateWhereIn {
}
fn name(&self) -> &str {
- "decorrelate_where_in"
+ "decorrelate_predicate_subquery"
}
fn apply_order(&self) -> Option<ApplyOrder> {
@@ -127,7 +165,7 @@ impl OptimizerRule for DecorrelateWhereIn {
}
}
-/// Optimize the where in subquery to left-anti/left-semi join.
+/// Optimize the subquery to left-anti/left-semi join.
/// If the subquery is a correlated subquery, we need extract the join predicate from the subquery.
///
/// For example, given a query like:
@@ -140,89 +178,198 @@ impl OptimizerRule for DecorrelateWhereIn {
/// LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c
/// TableScan: t1
/// SubqueryAlias: __correlated_sq_1
-/// Projection: t2.a AS a, t2.b, t2.c
+/// Projection: t2.a, t2.b, t2.c
/// TableScan: t2
/// ```
-fn optimize_where_in(
+///
+/// Given another query like:
+/// `select t1.id from t1 where exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)`
+///
+/// The optimized plan will be:
+///
+/// ```text
+/// Projection: t1.id
+/// LeftSemi Join: Filter: t1.id = __correlated_sq_1.id
+/// TableScan: t1
+/// SubqueryAlias: __correlated_sq_1
+/// Projection: t2.id
+/// TableScan: t2
+/// ```
+fn build_join(
query_info: &SubqueryInfo,
left: &LogicalPlan,
alias: &AliasGenerator,
-) -> Result<LogicalPlan> {
- let projection = Projection::try_from_plan(&query_info.query.subquery)
- .map_err(|e| context!("a projection is required", e))?;
- let subquery_input = projection.input.clone();
- // TODO add the validate logic to Analyzer
- let subquery_expr = only_or_err(projection.expr.as_slice())
- .map_err(|e| context!("single expression projection required", e))?;
-
- // extract join filters
- let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?;
-
- // in_predicate may be also include in the join filters, remove it from the join filters.
- let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone());
- let join_filters = remove_duplicated_filter(join_filters, in_predicate);
-
- // replace qualified name with subquery alias.
+) -> Result<Option<LogicalPlan>> {
+ let in_predicate = query_info
+ .where_in_expr
+ .clone()
+ .map(|in_expr| {
+ let projection = Projection::try_from_plan(&query_info.query.subquery)
+ .map_err(|e| context!("a projection is required", e))?;
+ // TODO add the validate logic to Analyzer
+ let subquery_expr = only_or_err(projection.expr.as_slice())
+ .map_err(|e| context!("single expression projection required", e))?;
+
+ // in_predicate may also be included in the join filters
+ Ok(Expr::eq(in_expr, subquery_expr.clone()))
+ })
+ .map_or(Ok(None), |v: Result<Expr, DataFusionError>| v.map(Some))?;
+
+ let subquery = query_info.query.subquery.as_ref();
let subquery_alias = alias.next("__correlated_sq");
- let input_schema = subquery_input.schema();
- let mut subquery_cols = collect_subquery_cols(&join_filters, input_schema.clone())?;
- let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| {
- replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some)
- })?;
-
- // add projection
- if let Expr::Column(col) = subquery_expr {
- subquery_cols.remove(col);
+ if let Some((join_filter, subquery_plan)) =
+ pull_up_correlated_expr(subquery, in_predicate, &subquery_alias)?
+ {
+ let sub_query_alias = LogicalPlanBuilder::from(subquery_plan)
+ .alias(subquery_alias.clone())?
+ .build()?;
+ // join our sub query into the main plan
+ let join_type = match query_info.negated {
+ true => JoinType::LeftAnti,
+ false => JoinType::LeftSemi,
+ };
+ let new_plan = LogicalPlanBuilder::from(left.clone())
+ .join(
+ sub_query_alias,
+ join_type,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(join_filter),
+ )?
+ .build()?;
+ debug!(
+ "predicate subquery optimized:\n{}",
+ new_plan.display_indent()
+ );
+ Ok(Some(new_plan))
+ } else {
+ Ok(None)
}
- let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone()));
- let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone());
- let projection_exprs: Vec<Expr> = [first_expr]
- .into_iter()
- .chain(subquery_cols.into_iter().map(Expr::Column))
- .collect();
-
- let right = LogicalPlanBuilder::from(subquery_input)
- .project(projection_exprs)?
- .alias(subquery_alias.clone())?
- .build()?;
-
- // join our sub query into the main plan
- let join_type = match query_info.negated {
- true => JoinType::LeftAnti,
- false => JoinType::LeftSemi,
- };
- let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name);
- let in_predicate = Expr::eq(
- query_info.where_in_expr.clone(),
- Expr::Column(right_join_col),
- );
- let join_filter = join_filter
- .map(|filter| in_predicate.clone().and(filter))
- .unwrap_or_else(|| in_predicate);
-
- let new_plan = LogicalPlanBuilder::from(left.clone())
- .join(
- right,
- join_type,
- (Vec::<Column>::new(), Vec::<Column>::new()),
- Some(join_filter),
- )?
- .build()?;
-
- debug!("where in optimized:\n{}", new_plan.display_indent());
- Ok(new_plan)
}
-fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> Vec<Expr> {
+/// This function pulls up the correlated expressions (those containing outer reference columns) from the inner subquery's [Filter].
+/// It adds the inner reference columns to the [Projection] of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition.
+///
+/// This function can't handle a non-correlated subquery that has no IN predicate, and will return None in that case.
+fn pull_up_correlated_expr(
+ subquery: &LogicalPlan,
+ in_predicate_opt: Option<Expr>,
+ subquery_alias: &str,
+) -> Result<Option<(Expr, LogicalPlan)>> {
+ match subquery {
+ LogicalPlan::Distinct(subqry_distinct) => {
+ let distinct_input = &subqry_distinct.input;
+ let optimized_plan = pull_up_correlated_expr(
+ distinct_input,
+ in_predicate_opt,
+ subquery_alias,
+ )?
+ .map(|(filters, right)| {
+ (
+ filters,
+ LogicalPlan::Distinct(Distinct {
+ input: Arc::new(right),
+ }),
+ )
+ });
+ Ok(optimized_plan)
+ }
+ LogicalPlan::Projection(projection) => {
+ // extract join filters from the inner subquery's Filter
+ let (mut join_filters, subquery_input) =
+ extract_join_filters(&projection.input)?;
+ if in_predicate_opt.is_none() && join_filters.is_empty() {
+ // cannot rewrite non-correlated subquery
+ return Ok(None);
+ }
+
+ if let Some(in_predicate) = &in_predicate_opt {
+ // in_predicate may be already included in the join filters, remove it from the join filters first.
+ join_filters = remove_duplicated_filter(join_filters, in_predicate);
+ }
+ let input_schema = subquery_input.schema();
+ let correlated_subquery_cols =
+ collect_subquery_cols(&join_filters, input_schema.clone())?;
+
+ // add missing columns to projection
+ let mut project_exprs: Vec<Expr> =
+ if let Some(Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op: Operator::Eq,
+ right,
+ })) = &in_predicate_opt
+ {
+ if !matches!(right.deref(), Expr::Column(_)) {
+ vec![right.deref().clone().alias(format!(
+ "{:?}",
+ unnormalize_col(right.deref().clone())
+ ))]
+ } else {
+ vec![right.deref().clone()]
+ }
+ } else {
+ vec![]
+ };
+ // the inner reference cols need to be added to the projection if they are missing.
+ for col in correlated_subquery_cols.iter() {
+ let col_expr = Expr::Column(col.clone());
+ if !project_exprs.contains(&col_expr) {
+ project_exprs.push(col_expr)
+ }
+ }
+
+ // alias the join filter
+ let join_filter_opt =
+ conjunction(join_filters).map_or(Ok(None), |filter| {
+ replace_qualified_name(
+ filter,
+ &correlated_subquery_cols,
+ subquery_alias,
+ )
+ .map(Option::Some)
+ })?;
+
+ let join_filter = if let Some(Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: Operator::Eq,
+ right,
+ })) = in_predicate_opt
+ {
+ let right_expr_name =
+ format!("{:?}", unnormalize_col(right.deref().clone()));
+ let right_col =
+ Column::new(Some(subquery_alias.to_string()), right_expr_name);
+ let in_predicate =
+ Expr::eq(left.deref().clone(), Expr::Column(right_col));
+ join_filter_opt
+ .map(|filter| in_predicate.clone().and(filter))
+ .unwrap_or_else(|| in_predicate)
+ } else {
+ join_filter_opt.ok_or_else(|| {
+ DataFusionError::Internal(
+ "join filters should not be empty".to_string(),
+ )
+ })?
+ };
+
+ let right = LogicalPlanBuilder::from(subquery_input)
+ .project(project_exprs)?
+ .build()?;
+ Ok(Some((join_filter, right)))
+ }
+ _ => Ok(None),
+ }
+}
+
+fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: &Expr) -> Vec<Expr> {
filters
.into_iter()
.filter(|filter| {
- if filter == &in_predicate {
+ if filter == in_predicate {
return false;
}
// ignore the binary order
- !match (filter, &in_predicate) {
+ !match (filter, in_predicate) {
(Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
(a_expr.op == b_expr.op)
&& (a_expr.left == b_expr.left && a_expr.right == b_expr.right)
@@ -236,15 +383,23 @@ fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> Vec<Expr>
struct SubqueryInfo {
query: Subquery,
- where_in_expr: Expr,
+ where_in_expr: Option<Expr>,
negated: bool,
}
impl SubqueryInfo {
- pub fn new(query: Subquery, expr: Expr, negated: bool) -> Self {
+ pub fn new(query: Subquery, negated: bool) -> Self {
+ Self {
+ query,
+ where_in_expr: None,
+ negated,
+ }
+ }
+
+ pub fn new_with_in_expr(query: Subquery, expr: Expr, negated: bool) -> Self {
Self {
query,
- where_in_expr: expr,
+ where_in_expr: Some(expr),
negated,
}
}
@@ -257,14 +412,15 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{
- and, binary_expr, col, in_subquery, lit, logical_plan::LogicalPlanBuilder,
- not_in_subquery, or, out_ref_col, Operator,
+ and, binary_expr, col, exists, in_subquery, lit,
+ logical_plan::LogicalPlanBuilder, not_exists, not_in_subquery, or, out_ref_col,
+ Operator,
};
use std::ops::Add;
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
plan,
expected,
);
@@ -297,10 +453,10 @@ mod tests {
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq_1.c AS c [c:UInt32]\
+ \n Projection: sq_1.c [c:UInt32]\
\n TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_2 [c:UInt32]\
- \n Projection: sq_2.c AS c [c:UInt32]\
+ \n Projection: sq_2.c [c:UInt32]\
\n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(&plan, expected)
}
@@ -325,7 +481,7 @@ mod tests {
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq.c AS c [c:UInt32]\
+ \n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(&plan, expected)
@@ -379,7 +535,7 @@ mod tests {
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq2.c AS c [c:UInt32]\
+ \n Projection: sq2.c [c:UInt32]\
\n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(&plan, expected)
@@ -404,11 +560,11 @@ mod tests {
\n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
- \n Projection: sq.a AS a [a:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
\n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_2 [c:UInt32]\
- \n Projection: sq_nested.c AS c [c:UInt32]\
+ \n Projection: sq_nested.c [c:UInt32]\
\n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(&plan, expected)
@@ -440,7 +596,7 @@ mod tests {
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq_inner.c AS c [c:UInt32]\
+ \n Projection: sq_inner.c [c:UInt32]\
\n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_equal(&plan, expected)
@@ -479,14 +635,14 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -528,15 +684,15 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\
- \n Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\
+ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\
\n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -566,12 +722,12 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -600,11 +756,11 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -630,12 +786,12 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -664,11 +820,11 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -697,11 +853,11 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -731,11 +887,11 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
+ \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -759,7 +915,7 @@ mod tests {
// Maybe okay if the table only has a single column?
assert_optimizer_err(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
"a projection is required",
);
@@ -788,11 +944,11 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -825,7 +981,7 @@ mod tests {
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -854,7 +1010,7 @@ mod tests {
.build()?;
assert_optimizer_err(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
"single expression projection required",
);
@@ -887,11 +1043,11 @@ mod tests {
\n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
\n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
\n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
- \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
\n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -929,7 +1085,7 @@ mod tests {
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -955,11 +1111,11 @@ mod tests {
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\
- \n Projection: sq.c AS c, sq.a [c:UInt32, a:UInt32]\
+ \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -979,11 +1135,11 @@ mod tests {
\n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq.c AS c [c:UInt32]\
+ \n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -1003,11 +1159,11 @@ mod tests {
\n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: sq.c AS c [c:UInt32]\
+ \n Projection: sq.c [c:UInt32]\
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -1036,7 +1192,7 @@ mod tests {
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -1071,7 +1227,7 @@ mod tests {
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -1107,7 +1263,7 @@ mod tests {
\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -1153,7 +1309,7 @@ mod tests {
\n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
@@ -1179,15 +1335,572 @@ mod tests {
\n LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
\n SubqueryAlias: __correlated_sq_1 [c:UInt32]\
- \n Projection: test.c AS c [c:UInt32]\
+ \n Projection: test.c [c:UInt32]\
\n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
&plan,
expected,
);
Ok(())
}
+
+ /// Test for multiple exists subqueries in the same filter expression
+ #[test]
+ fn multiple_exists_subqueries() -> Result<()> {
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("orders.o_custkey")
+ .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(orders.clone()).and(exists(orders)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test recursive correlated subqueries
+ #[test]
+ fn recursive_exists_subqueries() -> Result<()> {
+ let lineitem = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
+ .filter(
+ col("lineitem.l_orderkey")
+ .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
+ )?
+ .project(vec![col("lineitem.l_orderkey")])?
+ .build()?,
+ );
+
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ exists(lineitem).and(
+ col("orders.o_custkey")
+ .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
+ ),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(orders))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n LeftSemi Join: Filter: __correlated_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\
+ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\
+ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated exists subquery filter with additional subquery filters
+ #[test]
+ fn exists_subquery_with_subquery_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .and(col("o_orderkey").eq(lit(1))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn exists_subquery_no_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // Another rule will push down `customer.c_custkey = 1`.
+ // TODO: revisit the logic; is it a valid physical plan when there are no columns in the projection?
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 []\
+ \n Projection: []\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for exists subquery whose filter only compares columns of the subquery's own schema (no outer reference)
+ #[test]
+ fn exists_subquery_with_no_correlated_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan)
+ }
+
+ /// Test for correlated exists subquery not equal
+ #[test]
+ fn exists_subquery_where_not_eq() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .not_eq(col("orders.o_custkey")),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated exists subquery less than
+ #[test]
+ fn exists_subquery_where_less_than() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .lt(col("orders.o_custkey")),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated exists subquery filter with subquery disjunction
+ #[test]
+ fn exists_subquery_with_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .or(col("o_orderkey").eq(lit(1))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\
+ \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated exists without projection
+ #[test]
+ fn exists_subquery_no_projection() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan)
+ }
+
+ /// Test for correlated exists expressions
+ #[test]
+ fn exists_subquery_project_expr() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .eq(col("orders.o_custkey")),
+ )?
+ .project(vec![col("orders.o_custkey").add(lit(1))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated exists subquery filter with additional filters
+ #[test]
+ fn exists_subquery_should_support_additional_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ out_ref_col(DataType::Int64, "customer.c_custkey")
+ .eq(col("orders.o_custkey")),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq).and(col("c_custkey").eq(lit(1))))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
+ \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
+ \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\
+ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
+ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\
+ \n Projection: orders.o_custkey [o_custkey:Int64]\
+ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated exists subquery filter with disjunctions
+ #[test]
+ fn exists_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // not optimized
+ let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
+ Filter: EXISTS (<subquery>) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
+ Subquery: [o_custkey:Int64]
+ Projection: orders.o_custkey [o_custkey:Int64]
+ Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for correlated EXISTS subquery filter
+ #[test]
+ fn exists_subquery_correlated() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
+ .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
+ .project(vec![col("c")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
+ .filter(exists(sq))?
+ .project(vec![col("test.c")])?
+ .build()?;
+
+ let expected = "Projection: test.c [c:UInt32]\
+ \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ /// Test for single exists subquery filter
+ #[test]
+ fn exists_subquery_simple() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(test_subquery_with_name("sq")?))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan)
+ }
+
+ /// Test for single NOT exists subquery filter
+ #[test]
+ fn not_exists_subquery_simple() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(not_exists(test_subquery_with_name("sq")?))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan)
+ }
+
+ #[test]
+ fn two_exists_subquery_with_outer_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan1 = test_table_scan_with_name("sq1")?;
+ let subquery_scan2 = test_table_scan_with_name("sq2")?;
+
+ let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
+ .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))?
+ .project(vec![col("c")])?
+ .build()?;
+
+ let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
+ .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))?
+ .project(vec![col("c")])?
+ .build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(
+ exists(Arc::new(subquery1))
+ .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))),
+ )?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n Projection: sq1.a [a:UInt32]\
+ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\
+ \n Projection: sq2.a [a:UInt32]\
+ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn exists_subquery_expr_filter() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(
+ (lit(1u32) + col("sq.a"))
+ .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+ )?
+ .project(vec![lit(1u32)])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn exists_subquery_with_same_table() -> Result<()> {
+ let outer_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan()?;
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(col("test.a").gt(col("test.b")))?
+ .project(vec![col("c")])?
+ .build()?;
+
+ let plan = LogicalPlanBuilder::from(outer_scan)
+ .filter(exists(Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ // Subquery and outer query refer to the same table.
+ let expected = "Projection: test.b [b:UInt32]\
+ \n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
+ \n Subquery: [c:UInt32]\
+ \n Projection: test.c [c:UInt32]\
+ \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn exists_distinct_subquery() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(
+ (lit(1u32) + col("sq.a"))
+ .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+ )?
+ .project(vec![col("sq.c")])?
+ .distinct()?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n Distinct: [a:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn exists_distinct_expr_subquery() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(
+ (lit(1u32) + col("sq.a"))
+ .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+ )?
+ .project(vec![col("sq.b") + col("sq.c")])?
+ .distinct()?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n Distinct: [a:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
+
+ #[test]
+ fn exists_distinct_subquery_with_literal() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let subquery_scan = test_table_scan_with_name("sq")?;
+ let subquery = LogicalPlanBuilder::from(subquery_scan)
+ .filter(
+ (lit(1u32) + col("sq.a"))
+ .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
+ )?
+ .project(vec![lit(1u32), col("sq.c")])?
+ .distinct()?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(Arc::new(subquery)))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "Projection: test.b [b:UInt32]\
+ \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
+ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\
+ \n Distinct: [a:UInt32]\
+ \n Projection: sq.a [a:UInt32]\
+ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_optimized_plan_equal(&plan, expected)
+ }
}
diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs
deleted file mode 100644
index 0f143d4533..0000000000
--- a/datafusion/optimizer/src/decorrelate_where_exists.rs
+++ /dev/null
@@ -1,781 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use crate::optimizer::ApplyOrder;
-use crate::utils::{
- collect_subquery_cols, conjunction, extract_join_filters, split_conjunction,
-};
-use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{Column, DataFusionError, Result};
-use datafusion_expr::expr::Exists;
-use datafusion_expr::{
- logical_plan::{Distinct, Filter, JoinType, Subquery},
- Expr, LogicalPlan, LogicalPlanBuilder,
-};
-
-use std::sync::Arc;
-
-/// Optimizer rule for rewriting subquery filters to joins
-#[derive(Default)]
-pub struct DecorrelateWhereExists {}
-
-impl DecorrelateWhereExists {
- #[allow(missing_docs)]
- pub fn new() -> Self {
- Self {}
- }
-
- /// Finds expressions that have a where in subquery (and recurse when found)
- ///
- /// # Arguments
- ///
- /// * `predicate` - A conjunction to split and search
- /// * `optimizer_config` - For generating unique subquery aliases
- ///
- /// Returns a tuple (subqueries, non-subquery expressions)
- fn extract_subquery_exprs(
- &self,
- predicate: &Expr,
- config: &dyn OptimizerConfig,
- ) -> Result<(Vec<SubqueryInfo>, Vec<Expr>)> {
- let filters = split_conjunction(predicate);
-
- let mut subqueries = vec![];
- let mut others = vec![];
- for it in filters.iter() {
- match it {
- Expr::Exists(Exists { subquery, negated }) => {
- let subquery_plan = self
- .try_optimize(&subquery.subquery, config)?
- .map(Arc::new)
- .unwrap_or_else(|| subquery.subquery.clone());
- let new_subquery = subquery.with_plan(subquery_plan);
- subqueries.push(SubqueryInfo::new(new_subquery, *negated));
- }
- _ => others.push((*it).clone()),
- }
- }
-
- Ok((subqueries, others))
- }
-}
-
-impl OptimizerRule for DecorrelateWhereExists {
- fn try_optimize(
- &self,
- plan: &LogicalPlan,
- config: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
- match plan {
- LogicalPlan::Filter(filter) => {
- let (subqueries, other_exprs) =
- self.extract_subquery_exprs(&filter.predicate, config)?;
- if subqueries.is_empty() {
- // regular filter, no subquery exists clause here
- return Ok(None);
- }
-
- // iterate through all exists clauses in predicate, turning each into a join
- let mut cur_input = filter.input.as_ref().clone();
- for subquery in subqueries {
- if let Some(x) = optimize_exists(&subquery, &cur_input)? {
- cur_input = x;
- } else {
- return Ok(None);
- }
- }
-
- let expr = conjunction(other_exprs);
- if let Some(expr) = expr {
- let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
- cur_input = LogicalPlan::Filter(new_filter);
- }
-
- Ok(Some(cur_input))
- }
- _ => Ok(None),
- }
- }
-
- fn name(&self) -> &str {
- "decorrelate_where_exists"
- }
-
- fn apply_order(&self) -> Option<ApplyOrder> {
- Some(ApplyOrder::TopDown)
- }
-}
-
-/// Takes a query like:
-///
-/// SELECT t1.id
-/// FROM t1
-/// WHERE exists
-/// (
-/// SELECT t2.id FROM t2 WHERE t1.id = t2.id
-/// )
-///
-/// and optimizes it into:
-///
-/// SELECT t1.id
-/// FROM t1 LEFT SEMI
-/// JOIN t2
-/// ON t1.id = t2.id
-///
-/// # Arguments
-///
-/// * query_info - The subquery and negated(exists/not exists) info.
-/// * outer_input - The non-subquery portion (relation t1)
-fn optimize_exists(
- query_info: &SubqueryInfo,
- outer_input: &LogicalPlan,
-) -> Result<Option<LogicalPlan>> {
- let subquery = query_info.query.subquery.as_ref();
- if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery)? {
- // join our sub query into the main plan
- let join_type = match query_info.negated {
- true => JoinType::LeftAnti,
- false => JoinType::LeftSemi,
- };
-
- let new_plan = LogicalPlanBuilder::from(outer_input.clone())
- .join(
- optimized_subquery,
- join_type,
- (Vec::<Column>::new(), Vec::<Column>::new()),
- Some(join_filter),
- )?
- .build()?;
-
- Ok(Some(new_plan))
- } else {
- Ok(None)
- }
-}
-/// Optimize the subquery and extract the possible join filter.
-/// This function can't optimize non-correlated subquery, and will return None.
-fn optimize_subquery(subquery: &LogicalPlan) -> Result<Option<(Expr, LogicalPlan)>> {
- match subquery {
- LogicalPlan::Distinct(subqry_distinct) => {
- let distinct_input = &subqry_distinct.input;
- let optimized_plan =
- optimize_subquery(distinct_input)?.map(|(filters, right)| {
- (
- filters,
- LogicalPlan::Distinct(Distinct {
- input: Arc::new(right),
- }),
- )
- });
- Ok(optimized_plan)
- }
- LogicalPlan::Projection(projection) => {
- // extract join filters
- let (join_filters, subquery_input) = extract_join_filters(&projection.input)?;
- // cannot optimize non-correlated subquery
- if join_filters.is_empty() {
- return Ok(None);
- }
- let input_schema = subquery_input.schema();
- let project_exprs: Vec<Expr> =
- collect_subquery_cols(&join_filters, input_schema.clone())?
- .into_iter()
- .map(Expr::Column)
- .collect();
- let right = LogicalPlanBuilder::from(subquery_input)
- .project(project_exprs)?
- .build()?;
-
- // join_filters is not empty.
- let join_filter = conjunction(join_filters).ok_or_else(|| {
- DataFusionError::Internal("join filters should not be empty".to_string())
- })?;
- Ok(Some((join_filter, right)))
- }
- _ => Ok(None),
- }
-}
-
-struct SubqueryInfo {
- query: Subquery,
- negated: bool,
-}
-
-impl SubqueryInfo {
- pub fn new(query: Subquery, negated: bool) -> Self {
- Self { query, negated }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::test::*;
- use arrow::datatypes::DataType;
- use datafusion_common::Result;
- use datafusion_expr::{
- col, exists, lit, logical_plan::LogicalPlanBuilder, not_exists, out_ref_col,
- };
- use std::ops::Add;
-
- fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
- assert_optimized_plan_eq_display_indent(
- Arc::new(DecorrelateWhereExists::new()),
- plan,
- expected,
- );
- Ok(())
- }
-
- /// Test for multiple exists subqueries in the same filter expression
- #[test]
- fn multiple_subqueries() -> Result<()> {
- let orders = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- col("orders.o_custkey")
- .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(orders.clone()).and(exists(orders)))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
- \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
- assert_plan_eq(&plan, expected)
- }
-
- /// Test recursive correlated subqueries
- #[test]
- fn recursive_subqueries() -> Result<()> {
- let lineitem = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
- .filter(
- col("lineitem.l_orderkey")
- .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
- )?
- .project(vec![col("lineitem.l_orderkey")])?
- .build()?,
- );
-
- let orders = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- exists(lineitem).and(
- col("orders.o_custkey")
- .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
- ),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(orders))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\
- \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]";
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated exists subquery filter with additional subquery filters
- #[test]
- fn exists_subquery_with_subquery_filters() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .eq(col("orders.o_custkey"))
- .and(col("o_orderkey").eq(lit(1))),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- #[test]
- fn exists_subquery_no_cols() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- // Other rule will pushdown `customer.c_custkey = 1`,
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: []\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for exists subquery with both columns in schema
- #[test]
- fn exists_subquery_with_no_correlated_cols() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
- }
-
- /// Test for correlated exists subquery not equal
- #[test]
- fn exists_subquery_where_not_eq() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .not_eq(col("orders.o_custkey")),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated exists subquery less than
- #[test]
- fn exists_subquery_where_less_than() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .lt(col("orders.o_custkey")),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated exists subquery filter with subquery disjunction
- #[test]
- fn exists_subquery_with_subquery_disjunction() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .eq(col("orders.o_custkey"))
- .or(col("o_orderkey").eq(lit(1))),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated exists without projection
- #[test]
- fn exists_subquery_no_projection() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
- }
-
- /// Test for correlated exists expressions
- #[test]
- fn exists_subquery_project_expr() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .eq(col("orders.o_custkey")),
- )?
- .project(vec![col("orders.o_custkey").add(lit(1))])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated exists subquery filter with additional filters
- #[test]
- fn should_support_additional_filters() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(
- out_ref_col(DataType::Int64, "customer.c_custkey")
- .eq(col("orders.o_custkey")),
- )?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq).and(col("c_custkey").eq(lit(1))))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- let expected = "Projection: customer.c_custkey [c_custkey:Int64]\
- \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\
- \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\
- \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\
- \n Projection: orders.o_custkey [o_custkey:Int64]\
- \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated exists subquery filter with disjunctions
- #[test]
- fn exists_subquery_disjunction() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(scan_tpch_table("orders"))
- .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
- .project(vec![col("orders.o_custkey")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
- .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))?
- .project(vec![col("customer.c_custkey")])?
- .build()?;
-
- // not optimized
- let expected = r#"Projection: customer.c_custkey [c_custkey:Int64]
- Filter: EXISTS (<subquery>) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
- Subquery: [o_custkey:Int64]
- Projection: orders.o_custkey [o_custkey:Int64]
- Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
- TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for correlated EXISTS subquery filter
- #[test]
- fn exists_subquery_correlated() -> Result<()> {
- let sq = Arc::new(
- LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
- .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
- .project(vec![col("c")])?
- .build()?,
- );
-
- let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
- .filter(exists(sq))?
- .project(vec![col("test.c")])?
- .build()?;
-
- let expected = "Projection: test.c [c:UInt32]\
- \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n Projection: sq.a [a:UInt32]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-
- /// Test for single exists subquery filter
- #[test]
- fn exists_subquery_simple() -> Result<()> {
- let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(exists(test_subquery_with_name("sq")?))?
- .project(vec![col("test.b")])?
- .build()?;
-
- assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
- }
-
- /// Test for single NOT exists subquery filter
- #[test]
- fn not_exists_subquery_simple() -> Result<()> {
- let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(not_exists(test_subquery_with_name("sq")?))?
- .project(vec![col("test.b")])?
- .build()?;
-
- assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
- }
-
- #[test]
- fn two_exists_subquery_with_outer_filter() -> Result<()> {
- let table_scan = test_table_scan()?;
- let subquery_scan1 = test_table_scan_with_name("sq1")?;
- let subquery_scan2 = test_table_scan_with_name("sq2")?;
-
- let subquery1 = LogicalPlanBuilder::from(subquery_scan1)
- .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq1.a")))?
- .project(vec![col("c")])?
- .build()?;
-
- let subquery2 = LogicalPlanBuilder::from(subquery_scan2)
- .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq2.a")))?
- .project(vec![col("c")])?
- .build()?;
-
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(
- exists(Arc::new(subquery1))
- .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))),
- )?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\
- \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n Projection: sq1.a [a:UInt32]\
- \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\
- \n Projection: sq2.a [a:UInt32]\
- \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-
- #[test]
- fn exists_subquery_expr_filter() -> Result<()> {
- let table_scan = test_table_scan()?;
- let subquery_scan = test_table_scan_with_name("sq")?;
- let subquery = LogicalPlanBuilder::from(subquery_scan)
- .filter(
- (lit(1u32) + col("sq.a"))
- .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
- )?
- .project(vec![lit(1u32)])?
- .build()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(exists(Arc::new(subquery)))?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n Projection: sq.a [a:UInt32]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-
- #[test]
- fn exists_subquery_with_same_table() -> Result<()> {
- let outer_scan = test_table_scan()?;
- let subquery_scan = test_table_scan()?;
- let subquery = LogicalPlanBuilder::from(subquery_scan)
- .filter(col("test.a").gt(col("test.b")))?
- .project(vec![col("c")])?
- .build()?;
-
- let plan = LogicalPlanBuilder::from(outer_scan)
- .filter(exists(Arc::new(subquery)))?
- .project(vec![col("test.b")])?
- .build()?;
-
- // Subquery and outer query refer to the same table.
- let expected = "Projection: test.b [b:UInt32]\
- \n Filter: EXISTS (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
- \n Subquery: [c:UInt32]\
- \n Projection: test.c [c:UInt32]\
- \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-
- #[test]
- fn exists_distinct_subquery() -> Result<()> {
- let table_scan = test_table_scan()?;
- let subquery_scan = test_table_scan_with_name("sq")?;
- let subquery = LogicalPlanBuilder::from(subquery_scan)
- .filter(
- (lit(1u32) + col("sq.a"))
- .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
- )?
- .project(vec![col("sq.c")])?
- .distinct()?
- .build()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(exists(Arc::new(subquery)))?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n Distinct: [a:UInt32]\
- \n Projection: sq.a [a:UInt32]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-
- #[test]
- fn exists_distinct_expr_subquery() -> Result<()> {
- let table_scan = test_table_scan()?;
- let subquery_scan = test_table_scan_with_name("sq")?;
- let subquery = LogicalPlanBuilder::from(subquery_scan)
- .filter(
- (lit(1u32) + col("sq.a"))
- .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
- )?
- .project(vec![col("sq.b") + col("sq.c")])?
- .distinct()?
- .build()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(exists(Arc::new(subquery)))?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n Distinct: [a:UInt32]\
- \n Projection: sq.a [a:UInt32]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-
- #[test]
- fn exists_distinct_subquery_with_literal() -> Result<()> {
- let table_scan = test_table_scan()?;
- let subquery_scan = test_table_scan_with_name("sq")?;
- let subquery = LogicalPlanBuilder::from(subquery_scan)
- .filter(
- (lit(1u32) + col("sq.a"))
- .gt(out_ref_col(DataType::UInt32, "test.a") * lit(2u32)),
- )?
- .project(vec![lit(1u32), col("sq.c")])?
- .distinct()?
- .build()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .filter(exists(Arc::new(subquery)))?
- .project(vec![col("test.b")])?
- .build()?;
-
- let expected = "Projection: test.b [b:UInt32]\
- \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\
- \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\
- \n Distinct: [a:UInt32]\
- \n Projection: sq.a [a:UInt32]\
- \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]";
-
- assert_plan_eq(&plan, expected)
- }
-}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 42c0ccb484..2af5edbd9f 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -18,8 +18,7 @@
pub mod alias;
pub mod analyzer;
pub mod common_subexpr_eliminate;
-pub mod decorrelate_where_exists;
-pub mod decorrelate_where_in;
+pub mod decorrelate_predicate_subquery;
pub mod eliminate_cross_join;
pub mod eliminate_duplicated_expr;
pub mod eliminate_filter;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index e72833fae6..284280a3a5 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -18,8 +18,7 @@
//! Query optimizer traits
use crate::common_subexpr_eliminate::CommonSubexprEliminate;
-use crate::decorrelate_where_exists::DecorrelateWhereExists;
-use crate::decorrelate_where_in::DecorrelateWhereIn;
+use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery;
use crate::eliminate_cross_join::EliminateCrossJoin;
use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr;
use crate::eliminate_filter::EliminateFilter;
@@ -211,8 +210,7 @@ impl Optimizer {
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(ReplaceDistinctWithAggregate::new()),
- Arc::new(DecorrelateWhereExists::new()),
- Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelatePredicateSubquery::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(ExtractEquijoinPredicate::new()),
// simplify expressions does not simplify expressions in subqueries, so we
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index e8cced59fc..266d0a0be7 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -346,7 +346,7 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
}
}
-/// Extract join predicates from the correclated subquery.
+/// Extract join predicates from the correlated subquery's [Filter] expressions.
/// The join predicate means that the expression references columns
/// from both the subquery and outer table or only from the outer table.
///
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 43f329843e..761d6539b2 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -120,10 +120,11 @@ fn semi_join_with_join_filter() -> Result<()> {
AND test.col_uint32 != t2.col_uint32)";
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8\
- \n LeftSemi Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\
+ \n LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
- \n SubqueryAlias: t2\
- \n TableScan: test projection=[col_int32, col_uint32]";
+ \n SubqueryAlias: __correlated_sq_1\
+ \n SubqueryAlias: t2\
+ \n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
@@ -136,25 +137,26 @@ fn anti_join_with_join_filter() -> Result<()> {
AND test.col_uint32 != t2.col_uint32)";
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8\
- \n LeftAnti Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\
+ \n LeftAnti Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
- \n SubqueryAlias: t2\
- \n TableScan: test projection=[col_int32, col_uint32]";
+ \n SubqueryAlias: __correlated_sq_1\
+ \n SubqueryAlias: t2\
+ \n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
#[test]
fn where_exists_distinct() -> Result<()> {
- // regression test for https://github.com/apache/arrow-datafusion/issues/3724
let sql = "SELECT col_int32 FROM test WHERE EXISTS (\
SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
let plan = test_sql(sql)?;
- let expected = "LeftSemi Join: test.col_int32 = t2.col_int32\
+ let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32\
\n TableScan: test projection=[col_int32]\
- \n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\
- \n SubqueryAlias: t2\
- \n TableScan: test projection=[col_int32]";
+ \n SubqueryAlias: __correlated_sq_1\
+ \n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\
+ \n SubqueryAlias: t2\
+ \n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}