Posted to commits@arrow.apache.org by al...@apache.org on 2022/07/22 16:06:20 UTC
[arrow-datafusion] branch master updated: Add support for correlated subqueries & fix all related TPC-H benchmark issues (#2885)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7b0f2f846 Add support for correlated subqueries & fix all related TPC-H benchmark issues (#2885)
7b0f2f846 is described below
commit 7b0f2f846a7c8c2ffee2a4f29772cf3527a8d92c
Author: Brent Gardner <bg...@squarelabs.net>
AuthorDate: Fri Jul 22 10:06:14 2022 -0600
Add support for correlated subqueries & fix all related TPC-H benchmark issues (#2885)
* Failing test case for TPC-H query 20
* Fix name
* Broken test for adding intervals to dates
* Tests pass
* Fix rebase
* Fix query
* Additional tests
* Reduce to minimum failing (and passing) cases
* Adjust so data _should_ be returned, but see none
* Fixed data, decorrelated test passes
* Check in plans
* Put real assertion in place
* Add test for already working subquery optimizer
* Add decorrelator
* Check in broken test
* Add some passing and failing tests to see scope of problem
* Have almost all inputs needed for optimization, but need to catch 1 level earlier in tree
* Collected all inputs, now we just need to optimize
* Successfully decorrelated query 4
* refactor
* Pass test 4
* Ready for PR?
* Only operate on equality expressions
* Lint error
* Tests still pass because we are losing remaining predicate
* Don't lose remaining expressions
* Update test to expect remaining filter clause
* Debugging
* Can run query 4
* Remove debugging code
* Clippy
* Refactor where exists, add scalar subquery
* Login qty < () and 0.2 times, predicate pushdown is killing our plan
* Query plan looks good
* Fudge data to make test output nicer
* Fix syntax error
* [WIP] where in
* Working recursively, q20 plan looks good, but execution failing
* Fix CSV for execution error, remove silly variables in favor of --nocapture
* Silence verbose logs
* Query 21 test
* [WIP] refactoring, query 4 looking good
* [WIP] 4 & 17 look good
* 22 good?
* Check in "Test" for query 11
* query 11 works
* Don't throw away plans when multiple subqueries in one filter
* Manually decorrelate query 21
* [WIP] add data for query 21, anti join failing for some reason
* Does appear to be problem with anti-join
* Minimum failing test
* Verify anti join fix
* Repeatable tests
* cargo fmt
* Restore some optimizers and update test expectations
* Restore some optimizers and update test expectations
* Restore some optimizers and update test expectations
* Restore some optimizers and update test expectations
* Cleanup
* Cleanup scalar subquery, de-duplicate some code
* Cleanup
* Refactor
* Refactor
* Refactor
* Refactor
* Handle recursive where in
* Update assertions
* Support recursion in where exists queries
* Unit tests on where in
* Add correlated where in test
* Nasty code to make where in work for both correlated and uncorrelated queries
* Cleanup
* Refactoring
* Refactoring
* Add correlated unit test
* Add correlated where exists unit test
* [WIP] Failing scalar subquery unit test
* Refactor
* tuple mixup
* Scalar subquery unit test
* ASF header
* PR feedback
* PR feedback
* PR feedback
* PR feedback
* Fix build again
* Formatting
* Testing
* multiple where in
* Unit tests for where in
* where exists tests
* scalar subquery tests
* add aggregates to scalar subqueries
* Remove tests that only existed to get logical plans as input to unit tests
* Check in assertions for valid tests
* 1/33 passing unit tests :/
* Down to one failing test
* All the unit tests pass
* into methods
* Where exists unit tests passing
* Try from methods
* Fix tests
* Fix tests
* Refactor
* Fix test
* Refactor
* Fix test
* Fix error message
* Fix tests
* Fix tests
* Refactor
* Refactor and fix tests
* Improved recursive subquery test
* Recursive subquery fix
* Update tests
* Update tests
* Update tests
* Doc
* Clippy
* Linter & clippy
* Add doc, move test methods into test modules
* PR cleanup
* Inline test data
* Remove shared test data
* Remove shared test data
* Update tests
* Fix toml
* Update expectation
* PR feedback
* PR feedback
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Fix test to reveal logic error
* Simplify test
* Fix stuff, break other stuff
* I've written Scala in Rust because I'm in a hurry :(
* Clean the API up a little
* PR feedback
* PR feedback
* PR feedback
* PR feedback
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
benchmarks/queries/q20.sql | 2 +-
datafusion/common/src/error.rs | 27 +
datafusion/core/Cargo.toml | 2 +
datafusion/core/src/execution/context.rs | 6 +
.../core/src/physical_plan/coalesce_batches.rs | 4 +-
.../core/src/physical_plan/file_format/mod.rs | 4 +-
datafusion/core/tests/sql/mod.rs | 111 +++-
datafusion/core/tests/sql/subqueries.rs | 474 ++++++++++++++
datafusion/core/tests/tpch-csv/lineitem.csv | 4 +-
datafusion/core/tests/tpch-csv/nation.csv | 2 +-
datafusion/core/tests/tpch-csv/part.csv | 2 +
datafusion/core/tests/tpch-csv/partsupp.csv | 2 +
datafusion/core/tests/tpch-csv/region.csv | 2 +
datafusion/core/tests/tpch-csv/supplier.csv | 3 +
datafusion/expr/src/expr.rs | 9 +-
datafusion/expr/src/logical_plan/plan.rs | 36 +-
datafusion/optimizer/Cargo.toml | 5 +
.../optimizer/src/decorrelate_scalar_subquery.rs | 705 +++++++++++++++++++++
.../optimizer/src/decorrelate_where_exists.rs | 557 ++++++++++++++++
datafusion/optimizer/src/decorrelate_where_in.rs | 693 ++++++++++++++++++++
datafusion/optimizer/src/lib.rs | 3 +
datafusion/optimizer/src/test/mod.rs | 77 ++-
datafusion/optimizer/src/utils.rs | 235 ++++++-
23 files changed, 2952 insertions(+), 13 deletions(-)
diff --git a/benchmarks/queries/q20.sql b/benchmarks/queries/q20.sql
index f0339a601..dd61a7d8e 100644
--- a/benchmarks/queries/q20.sql
+++ b/benchmarks/queries/q20.sql
@@ -28,7 +28,7 @@ where
l_partkey = ps_partkey
and l_suppkey = ps_suppkey
and l_shipdate >= date '1994-01-01'
- and l_shipdate < 'date 1994-01-01' + interval '1' year
+ and l_shipdate < date '1994-01-01' + interval '1' year
)
)
and s_nationkey = n_nationkey
diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs
index c1d0f29b1..de5bbe8e0 100644
--- a/datafusion/common/src/error.rs
+++ b/datafusion/common/src/error.rs
@@ -83,6 +83,30 @@ pub enum DataFusionError {
#[cfg(feature = "jit")]
/// Error occurs during code generation
JITError(ModuleError),
+ /// Error with additional context
+ Context(String, Box<DataFusionError>),
+}
+
+#[macro_export]
+macro_rules! context {
+ ($desc:expr, $err:expr) => {
+ datafusion_common::DataFusionError::Context(
+ format!("{} at {}:{}", $desc, file!(), line!()),
+ Box::new($err),
+ )
+ };
+}
+
+#[macro_export]
+macro_rules! plan_err {
+ ($desc:expr) => {
+ Err(datafusion_common::DataFusionError::Plan(format!(
+ "{} at {}:{}",
+ $desc,
+ file!(),
+ line!()
+ )))
+ };
}
/// Schema-related errors
@@ -285,6 +309,9 @@ impl Display for DataFusionError {
DataFusionError::ObjectStore(ref desc) => {
write!(f, "Object Store error: {}", desc)
}
+ DataFusionError::Context(ref desc, ref err) => {
+ write!(f, "{}\ncaused by\n{}", desc, *err)
+ }
}
}
}
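The new `Context` variant and `context!` macro wrap an inner error with a `file!()`/`line!()` location string, and `Display` unwinds the chain with "caused by". The shape can be sketched standalone; `MyError` below is a hypothetical stand-in, not the actual `DataFusionError`:

```rust
use std::fmt;

// Hypothetical stand-in for the error shape above: an error either carries
// a plan message, or wraps another error together with a location string.
#[derive(Debug)]
enum MyError {
    Plan(String),
    Context(String, Box<MyError>),
}

impl fmt::Display for MyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MyError::Plan(desc) => write!(f, "Plan error: {}", desc),
            MyError::Context(desc, err) => write!(f, "{}\ncaused by\n{}", desc, err),
        }
    }
}

// Mirrors the `context!` macro: attach the current file and line to an error.
macro_rules! context {
    ($desc:expr, $err:expr) => {
        MyError::Context(format!("{} at {}:{}", $desc, file!(), line!()), Box::new($err))
    };
}

fn main() {
    let inner = MyError::Plan("unsupported subquery".to_string());
    let wrapped = context!("while decorrelating", inner);
    let shown = format!("{}", wrapped);
    assert!(shown.contains("caused by"));
    assert!(shown.contains("Plan error: unsupported subquery"));
}
```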
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index ac21d7f90..351e72fc2 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -94,6 +94,8 @@ uuid = { version = "1.0", features = ["v4"] }
[dev-dependencies]
criterion = "0.3"
+csv = "1.1.6"
+ctor = "0.1.22"
doc-comment = "0.3"
env_logger = "0.9"
fuzz-utils = { path = "fuzz-utils" }
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index e37d3b0ba..41964e33a 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -102,6 +102,9 @@ use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::ScalarValue;
use datafusion_expr::TableSource;
+use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery;
+use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
+use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_sql::{
parser::DFParser,
@@ -1356,6 +1359,9 @@ impl SessionState {
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
+ Arc::new(DecorrelateWhereExists::new()),
+ Arc::new(DecorrelateWhereIn::new()),
+ Arc::new(DecorrelateScalarSubquery::new()),
Arc::new(SubqueryFilterToJoin::new()),
Arc::new(EliminateFilter::new()),
Arc::new(CommonSubexprEliminate::new()),
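The three decorrelation rules are registered after `SimplifyExpressions` and before `SubqueryFilterToJoin`, so they see simplified predicates and hand their generated joins to later passes. A minimal sketch of that in-order pipeline, with a hypothetical `OptimizerRule` trait and a `String` standing in for a `LogicalPlan`:

```rust
// Hypothetical sketch: rules run in registration order, each rewriting the
// plan produced by the previous rule.
trait OptimizerRule {
    fn optimize(&self, plan: String) -> String; // a String stands in for LogicalPlan
}

struct TraceRule(&'static str);
impl OptimizerRule for TraceRule {
    fn optimize(&self, plan: String) -> String {
        format!("{} -> {}", plan, self.0) // record that this rule ran
    }
}

fn run_rules(rules: &[Box<dyn OptimizerRule>], plan: String) -> String {
    rules.iter().fold(plan, |p, r| r.optimize(p))
}

fn main() {
    let rules: Vec<Box<dyn OptimizerRule>> = vec![
        Box::new(TraceRule("SimplifyExpressions")),
        Box::new(TraceRule("DecorrelateWhereExists")),
        Box::new(TraceRule("DecorrelateWhereIn")),
        Box::new(TraceRule("DecorrelateScalarSubquery")),
        Box::new(TraceRule("SubqueryFilterToJoin")),
    ];
    let trace = run_rules(&rules, "plan".to_string());
    assert!(trace.starts_with("plan -> SimplifyExpressions"));
    assert!(trace.ends_with("SubqueryFilterToJoin"));
}
```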
diff --git a/datafusion/core/src/physical_plan/coalesce_batches.rs b/datafusion/core/src/physical_plan/coalesce_batches.rs
index 3f39caaef..a257ccf09 100644
--- a/datafusion/core/src/physical_plan/coalesce_batches.rs
+++ b/datafusion/core/src/physical_plan/coalesce_batches.rs
@@ -35,7 +35,7 @@ use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::stream::{Stream, StreamExt};
-use log::debug;
+use log::trace;
use super::expressions::PhysicalSortExpr;
use super::metrics::{BaselineMetrics, MetricsSet};
@@ -286,7 +286,7 @@ pub fn concat_batches(
)?;
arrays.push(array);
}
- debug!(
+ trace!(
"Combined {} batches containing {} rows",
batches.len(),
row_count
diff --git a/datafusion/core/src/physical_plan/file_format/mod.rs b/datafusion/core/src/physical_plan/file_format/mod.rs
index 3ea520b2c..c26b2d760 100644
--- a/datafusion/core/src/physical_plan/file_format/mod.rs
+++ b/datafusion/core/src/physical_plan/file_format/mod.rs
@@ -26,6 +26,8 @@ mod file_stream;
mod json;
mod parquet;
+pub(crate) use self::csv::plan_to_csv;
+pub use self::csv::CsvExec;
pub(crate) use self::parquet::plan_to_parquet;
pub use self::parquet::ParquetExec;
use arrow::{
@@ -36,8 +38,6 @@ use arrow::{
record_batch::RecordBatch,
};
pub use avro::AvroExec;
-pub(crate) use csv::plan_to_csv;
-pub use csv::CsvExec;
pub(crate) use json::plan_to_json;
pub use json::NdJsonExec;
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index a7f4cabe9..186584aeb 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -49,6 +49,7 @@ use datafusion_expr::Volatility;
use object_store::path::Path;
use std::fs::File;
use std::io::Write;
+use std::ops::Sub;
use std::path::PathBuf;
use tempfile::TempDir;
@@ -108,6 +109,7 @@ mod explain;
mod idenfifers;
pub mod information_schema;
mod partitioned_csv;
+mod subqueries;
#[cfg(feature = "unicode_expressions")]
pub mod unicode;
@@ -483,7 +485,43 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("n_comment", DataType::Utf8, false),
]),
- _ => unimplemented!(),
+ "supplier" => Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, false),
+ Field::new("s_name", DataType::Utf8, false),
+ Field::new("s_address", DataType::Utf8, false),
+ Field::new("s_nationkey", DataType::Int64, false),
+ Field::new("s_phone", DataType::Utf8, false),
+ Field::new("s_acctbal", DataType::Float64, false),
+ Field::new("s_comment", DataType::Utf8, false),
+ ]),
+
+ "partsupp" => Schema::new(vec![
+ Field::new("ps_partkey", DataType::Int64, false),
+ Field::new("ps_suppkey", DataType::Int64, false),
+ Field::new("ps_availqty", DataType::Int32, false),
+ Field::new("ps_supplycost", DataType::Float64, false),
+ Field::new("ps_comment", DataType::Utf8, false),
+ ]),
+
+ "part" => Schema::new(vec![
+ Field::new("p_partkey", DataType::Int64, false),
+ Field::new("p_name", DataType::Utf8, false),
+ Field::new("p_mfgr", DataType::Utf8, false),
+ Field::new("p_brand", DataType::Utf8, false),
+ Field::new("p_type", DataType::Utf8, false),
+ Field::new("p_size", DataType::Int32, false),
+ Field::new("p_container", DataType::Utf8, false),
+ Field::new("p_retailprice", DataType::Float64, false),
+ Field::new("p_comment", DataType::Utf8, false),
+ ]),
+
+ "region" => Schema::new(vec![
+ Field::new("r_regionkey", DataType::Int64, false),
+ Field::new("r_name", DataType::Utf8, false),
+ Field::new("r_comment", DataType::Utf8, false),
+ ]),
+
+ _ => unimplemented!("Table: {}", table),
}
}
@@ -499,6 +537,77 @@ async fn register_tpch_csv(ctx: &SessionContext, table: &str) -> Result<()> {
Ok(())
}
+async fn register_tpch_csv_data(
+ ctx: &SessionContext,
+ table_name: &str,
+ data: &str,
+) -> Result<()> {
+ let schema = Arc::new(get_tpch_table_schema(table_name));
+
+ let mut reader = ::csv::ReaderBuilder::new()
+ .has_headers(false)
+ .from_reader(data.as_bytes());
+ let records: Vec<_> = reader.records().map(|it| it.unwrap()).collect();
+
+ let mut cols: Vec<Box<dyn ArrayBuilder>> = vec![];
+ for field in schema.fields().iter() {
+ match field.data_type() {
+ DataType::Utf8 => cols.push(Box::new(StringBuilder::new(records.len()))),
+ DataType::Date32 => cols.push(Box::new(Date32Builder::new(records.len()))),
+ DataType::Int32 => cols.push(Box::new(Int32Builder::new(records.len()))),
+ DataType::Int64 => cols.push(Box::new(Int64Builder::new(records.len()))),
+ DataType::Float64 => cols.push(Box::new(Float64Builder::new(records.len()))),
+ _ => {
+ let msg = format!("Not implemented: {}", field.data_type());
+ Err(DataFusionError::Plan(msg))?
+ }
+ }
+ }
+
+ for record in records.iter() {
+ for (idx, val) in record.iter().enumerate() {
+ let col = cols.get_mut(idx).unwrap();
+ let field = schema.field(idx);
+ match field.data_type() {
+ DataType::Utf8 => {
+ let sb = col.as_any_mut().downcast_mut::<StringBuilder>().unwrap();
+ sb.append_value(val)?;
+ }
+ DataType::Date32 => {
+ let sb = col.as_any_mut().downcast_mut::<Date32Builder>().unwrap();
+ let dt = NaiveDate::parse_from_str(val.trim(), "%Y-%m-%d").unwrap();
+ let dt = dt.sub(NaiveDate::from_ymd(1970, 1, 1)).num_days() as i32;
+ sb.append_value(dt)?;
+ }
+ DataType::Int32 => {
+ let sb = col.as_any_mut().downcast_mut::<Int32Builder>().unwrap();
+ sb.append_value(val.trim().parse().unwrap())?;
+ }
+ DataType::Int64 => {
+ let sb = col.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
+ sb.append_value(val.trim().parse().unwrap())?;
+ }
+ DataType::Float64 => {
+ let sb = col.as_any_mut().downcast_mut::<Float64Builder>().unwrap();
+ sb.append_value(val.trim().parse().unwrap())?;
+ }
+ _ => Err(DataFusionError::Plan(format!(
+ "Not implemented: {}",
+ field.data_type()
+ )))?,
+ }
+ }
+ }
+ let cols: Vec<ArrayRef> = cols.iter_mut().map(|it| it.finish()).collect();
+
+ let batch = RecordBatch::try_new(Arc::clone(&schema), cols)?;
+
+ let table = Arc::new(MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?);
+ let _ = ctx.register_table(table_name, table).unwrap();
+
+ Ok(())
+}
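The helper above keeps one boxed builder per column, dispatched on the schema's `DataType`, and downcasts each builder back to its concrete type per cell. The downcast pattern itself can be sketched standalone with `std::any::Any`; the `ArrayBuilder`/`StringBuilder` types below are stand-ins for the Arrow builders, not the real API:

```rust
use std::any::Any;

// Stand-ins for Arrow's builder types, just to show the downcast pattern.
trait ArrayBuilder {
    fn as_any_mut(&mut self) -> &mut dyn Any;
    fn finish(&mut self) -> Vec<String>; // Arrow returns an ArrayRef here
}

struct StringBuilder(Vec<String>);
impl StringBuilder {
    fn append_value(&mut self, v: &str) {
        self.0.push(v.to_string());
    }
}
impl ArrayBuilder for StringBuilder {
    fn as_any_mut(&mut self) -> &mut dyn Any { self }
    fn finish(&mut self) -> Vec<String> { std::mem::take(&mut self.0) }
}

fn main() {
    // One boxed builder per column, chosen by the column's data type.
    let mut cols: Vec<Box<dyn ArrayBuilder>> = vec![Box::new(StringBuilder(Vec::new()))];

    // Per cell: recover the concrete builder and append the parsed value.
    let col = cols.get_mut(0).unwrap();
    let sb = col.as_any_mut().downcast_mut::<StringBuilder>().unwrap();
    sb.append_value("Clerk#000000951");

    let finished: Vec<Vec<String>> = cols.iter_mut().map(|c| c.finish()).collect();
    assert_eq!(finished[0], vec!["Clerk#000000951"]);
}
```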
+
async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
let testdata = datafusion::test_util::arrow_test_data();
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
new file mode 100644
index 000000000..4eaf921f6
--- /dev/null
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -0,0 +1,474 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+ let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_tpch_csv(&ctx, "customer").await?;
+ register_tpch_csv(&ctx, "orders").await?;
+ register_tpch_csv(&ctx, "lineitem").await?;
+
+ let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+ select sum(o_totalprice) from orders
+ where o_custkey = c_custkey
+ and o_totalprice < (
+ select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey
+ )
+) order by c_custkey;"#;
+
+ // assert plan
+ let plan = ctx.create_logical_plan(sql).unwrap();
+ debug!("input:\n{}", plan.display_indent());
+
+ let plan = ctx.optimize(&plan).unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+ Projection: #customer.c_custkey
+ Filter: #customer.c_acctbal < #__sq_2.__value
+ Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+ TableScan: customer projection=[c_custkey, c_acctbal]
+ Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, alias=__sq_2
+ Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[SUM(#orders.o_totalprice)]]
+ Filter: #orders.o_totalprice < #__sq_1.__value
+ Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+ TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]
+ Projection: #lineitem.l_orderkey, #SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+ Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[SUM(#lineitem.l_extendedprice)]]
+ TableScan: lineitem projection=[l_orderkey, l_extendedprice]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ Ok(())
+}
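The plan above shows the core scalar-subquery rewrite: instead of re-running `sum(o_totalprice)` per customer row, the subquery becomes a grouped aggregate joined back on the correlated key. The equivalence can be sketched in plain Rust on toy `(key, value)` rows, with assumed sample data; note an inner join drops rows with no match, mirroring a NULL aggregate failing the comparison:

```rust
use std::collections::HashMap;

// Correlated form: re-run the scalar subquery for every customer row.
fn correlated(customers: &[(i64, f64)], orders: &[(i64, f64)]) -> Vec<i64> {
    customers.iter()
        .filter(|(ck, bal)| {
            let matched: Vec<f64> = orders.iter()
                .filter(|(ok, _)| ok == ck)
                .map(|(_, p)| *p)
                .collect();
            // with no matching orders, SUM is NULL and the comparison is not true
            !matched.is_empty() && *bal < matched.iter().sum::<f64>()
        })
        .map(|(ck, _)| *ck)
        .collect()
}

// Decorrelated form: aggregate once per o_custkey, then inner join on the key.
fn decorrelated(customers: &[(i64, f64)], orders: &[(i64, f64)]) -> Vec<i64> {
    let mut sums: HashMap<i64, f64> = HashMap::new();
    for (ok, p) in orders {
        *sums.entry(*ok).or_insert(0.0) += p;
    }
    let mut out = Vec::new();
    for (ck, bal) in customers {
        if let Some(sum) = sums.get(ck) {
            if bal < sum {
                out.push(*ck);
            }
        }
    }
    out
}

fn main() {
    // (c_custkey, c_acctbal) and (o_custkey, o_totalprice): assumed toy data
    let customers: Vec<(i64, f64)> = vec![(1, 100.0), (2, 500.0), (3, 50.0)];
    let orders: Vec<(i64, f64)> = vec![(1, 80.0), (1, 70.0), (2, 200.0), (3, 90.0)];
    assert_eq!(correlated(&customers, &orders), decorrelated(&customers, &orders));
    assert_eq!(correlated(&customers, &orders), vec![1, 3]);
}
```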
+
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+ let orders = r#"1,3691,O,194029.55,1996-01-02,5-LOW,Clerk#000000951,0,
+65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0,
+"#;
+ let lineitems = r#"1,15519,785,1,17,24386.67,0.04,0.02,N,O,1996-03-13,1996-02-12,1996-03-22,DELIVER IN PERSON,TRUCK,
+1,6731,732,2,36,58958.28,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,
+65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK,
+65,7382,897,2,22,28366.36,0,0.05,N,O,1995-07-17,1995-06-04,1995-07-19,COLLECT COD,FOB,
+"#;
+
+ let ctx = SessionContext::new();
+ register_tpch_csv_data(&ctx, "orders", orders).await?;
+ register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+ let sql = r#"select o_orderkey from orders
+where o_orderstatus in (
+ select l_linestatus from lineitem where l_orderkey = orders.o_orderkey
+);"#;
+
+ // assert plan
+ let plan = ctx.create_logical_plan(sql).unwrap();
+ let plan = ctx.optimize(&plan).unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Projection: #orders.o_orderkey
+ Semi Join: #orders.o_orderstatus = #__sq_1.l_linestatus, #orders.o_orderkey = #__sq_1.l_orderkey
+ TableScan: orders projection=[o_orderkey, o_orderstatus]
+ Projection: #lineitem.l_linestatus AS l_linestatus, #lineitem.l_orderkey AS l_orderkey, alias=__sq_1
+ TableScan: lineitem projection=[l_orderkey, l_linestatus]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------------+",
+ "| o_orderkey |",
+ "+------------+",
+ "| 1 |",
+ "+------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
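The `Semi Join` in the expected plan is the key property of the `WHERE IN` rewrite: each outer row is emitted at most once, no matter how many subquery rows match. A standalone sketch of that equivalence on toy `(key, status)` rows (assumed sample data, not the TPC-H fixtures):

```rust
use std::collections::HashSet;

// IN-subquery form: evaluate the subquery per outer row.
fn in_subquery(orders: &[(i64, &str)], lineitem: &[(i64, &str)]) -> Vec<i64> {
    orders.iter()
        .filter(|(ok, st)| lineitem.iter().any(|(lk, ls)| lk == ok && ls == st))
        .map(|(ok, _)| *ok)
        .collect()
}

// Semi-join form: build the key set once, probe each outer row, emit at most once.
fn semi_join(orders: &[(i64, &str)], lineitem: &[(i64, &str)]) -> Vec<i64> {
    let keys: HashSet<(i64, &str)> = lineitem.iter().copied().collect();
    orders.iter()
        .filter(|(ok, st)| keys.contains(&(*ok, *st)))
        .map(|(ok, _)| *ok)
        .collect()
}

fn main() {
    let orders: Vec<(i64, &str)> = vec![(1, "O"), (65, "P")];
    // order 1 has two matching rows; a semi join must still emit it only once
    let lineitem: Vec<(i64, &str)> = vec![(1, "O"), (1, "O"), (65, "A"), (65, "F")];
    assert_eq!(in_subquery(&orders, &lineitem), semi_join(&orders, &lineitem));
    assert_eq!(semi_join(&orders, &lineitem), vec![1]);
}
```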
+
+#[tokio::test]
+async fn tpch_q2_correlated() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_tpch_csv(&ctx, "part").await?;
+ register_tpch_csv(&ctx, "supplier").await?;
+ register_tpch_csv(&ctx, "partsupp").await?;
+ register_tpch_csv(&ctx, "nation").await?;
+ register_tpch_csv(&ctx, "region").await?;
+
+ let sql = r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, s_address, s_phone, s_comment
+from part, supplier, partsupp, nation, region
+where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and p_type like '%BRASS'
+ and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 'EUROPE'
+ and ps_supplycost = (
+ select min(ps_supplycost) from partsupp, supplier, nation, region
+ where p_partkey = ps_partkey and s_suppkey = ps_suppkey and s_nationkey = n_nationkey
+ and n_regionkey = r_regionkey and r_name = 'EUROPE'
+ )
+order by s_acctbal desc, n_name, s_name, p_partkey;"#;
+
+ // assert plan
+ let plan = ctx.create_logical_plan(sql).unwrap();
+ let plan = ctx.optimize(&plan).unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, #nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey ASC NULLS LAST
+ Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, #part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, #supplier.s_comment
+ Filter: #partsupp.ps_supplycost = #__sq_1.__value
+ Inner Join: #part.p_partkey = #__sq_1.ps_partkey
+ Inner Join: #nation.n_regionkey = #region.r_regionkey
+ Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+ Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+ Inner Join: #part.p_partkey = #partsupp.ps_partkey
+ Filter: #part.p_size = Int64(15) AND #part.p_type LIKE Utf8("%BRASS")
+ TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE Utf8("%BRASS")]
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: #region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name], partial_filters=[#region.r_name = Utf8("EUROPE")]
+ Projection: #partsupp.ps_partkey, #MIN(partsupp.ps_supplycost) AS __value, alias=__sq_1
+ Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[MIN(#partsupp.ps_supplycost)]]
+ Inner Join: #nation.n_regionkey = #region.r_regionkey
+ Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+ Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+ TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+ Filter: #region.r_name = Utf8("EUROPE")
+ TableScan: region projection=[r_regionkey, r_name], partial_filters=[#region.r_name = Utf8("EUROPE")]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["++", "++"];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q4_correlated() -> Result<()> {
+ let orders = r#"4,13678,O,53829.87,1995-10-11,5-LOW,Clerk#000000124,0,
+35,12760,O,192885.43,1995-10-23,4-NOT SPECIFIED,Clerk#000000259,0,
+65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0,
+"#;
+ let lineitems = r#"4,8804,579,1,30,51384,0.03,0.08,N,O,1996-01-10,1995-12-14,1996-01-18,DELIVER IN PERSON,REG AIR,
+35,45,296,1,24,22680.96,0.02,0,N,O,1996-02-21,1996-01-03,1996-03-18,TAKE BACK RETURN,FOB,
+65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK,
+"#;
+
+ let ctx = SessionContext::new();
+ register_tpch_csv_data(&ctx, "orders", orders).await?;
+ register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+ let sql = r#"
+ select o_orderpriority, count(*) as order_count
+ from orders
+ where exists (
+ select * from lineitem where l_orderkey = o_orderkey and l_commitdate < l_receiptdate)
+ group by o_orderpriority
+ order by o_orderpriority;
+ "#;
+
+ // assert plan
+ let plan = ctx.create_logical_plan(sql).unwrap();
+ let plan = ctx.optimize(&plan).unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST
+ Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count
+ Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]]
+ Semi Join: #orders.o_orderkey = #lineitem.l_orderkey
+ TableScan: orders projection=[o_orderkey, o_orderpriority]
+ Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate
+ TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------------+-------------+",
+ "| o_orderpriority | order_count |",
+ "+-----------------+-------------+",
+ "| 1-URGENT | 1 |",
+ "| 4-NOT SPECIFIED | 1 |",
+ "| 5-LOW | 1 |",
+ "+-----------------+-------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q17_correlated() -> Result<()> {
+ let parts = r#"63700,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly ironi
+"#;
+ let lineitems = r#"1,63700,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,ly final dependencies: slyly bold
+1,63700,3701,3,1.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE BACK RETURN,REG AIR,"riously. regular, express dep"
+"#;
+
+ let ctx = SessionContext::new();
+ register_tpch_csv_data(&ctx, "part", parts).await?;
+ register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+ let sql = r#"select sum(l_extendedprice) / 7.0 as avg_yearly
+ from lineitem, part
+ where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 'MED BOX'
+ and l_quantity < (
+ select 0.2 * avg(l_quantity)
+ from lineitem where l_partkey = p_partkey
+ );"#;
+
+ // assert plan
+ let plan = ctx
+ .create_logical_plan(sql)
+ .map_err(|e| format!("{:?} at {}", e, "error"))
+ .unwrap();
+ println!("before:\n{}", plan.display_indent());
+ let plan = ctx
+ .optimize(&plan)
+ .map_err(|e| format!("{:?} at {}", e, "error"))
+ .unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly
+ Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
+ Filter: #lineitem.l_quantity < #__sq_1.__value
+ Inner Join: #part.p_partkey = #__sq_1.l_partkey
+ Inner Join: #lineitem.l_partkey = #part.p_partkey
+ TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice]
+ Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX")
+ TableScan: part projection=[p_partkey, p_brand, p_container]
+ Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS __value, alias=__sq_1
+ Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]]
+ TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+--------------------+",
+ "| avg_yearly |",
+ "+--------------------+",
+ "| 1901.3714285714286 |",
+ "+--------------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q20_correlated() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_tpch_csv(&ctx, "supplier").await?;
+ register_tpch_csv(&ctx, "nation").await?;
+ register_tpch_csv(&ctx, "partsupp").await?;
+ register_tpch_csv(&ctx, "part").await?;
+ register_tpch_csv(&ctx, "lineitem").await?;
+
+ let sql = r#"select s_name, s_address
+from supplier, nation
+where s_suppkey in (
+ select ps_suppkey from partsupp
+ where ps_partkey in ( select p_partkey from part where p_name like 'forest%' )
+ and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem
+ where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate >= date '1994-01-01'
+ )
+)
+and s_nationkey = n_nationkey and n_name = 'CANADA'
+order by s_name;
+"#;
+
+ // assert plan
+ let plan = ctx
+ .create_logical_plan(sql)
+ .map_err(|e| format!("{:?} at {}", e, "error"))
+ .unwrap();
+ let plan = ctx
+ .optimize(&plan)
+ .map_err(|e| format!("{:?} at {}", e, "error"))
+ .unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Sort: #supplier.s_name ASC NULLS LAST
+ Projection: #supplier.s_name, #supplier.s_address
+ Semi Join: #supplier.s_suppkey = #__sq_2.ps_suppkey
+ Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+ TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]
+ Filter: #nation.n_name = Utf8("CANADA")
+ TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("CANADA")]
+ Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
+ Filter: #partsupp.ps_availqty > #__sq_3.__value
+ Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, #partsupp.ps_suppkey = #__sq_3.l_suppkey
+ Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
+ Projection: #part.p_partkey AS p_partkey, alias=__sq_1
+ Filter: #part.p_name LIKE Utf8("forest%")
+ TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")]
+ Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
+ Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]]
+ Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
+ TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["++", "++"];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q22_correlated() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_tpch_csv(&ctx, "customer").await?;
+ register_tpch_csv(&ctx, "orders").await?;
+
+ let sql = r#"select cntrycode, count(*) as numcust, sum(c_acctbal) as totacctbal
+from (
+ select substring(c_phone from 1 for 2) as cntrycode, c_acctbal from customer
+ where substring(c_phone from 1 for 2) in ('13', '31', '23', '29', '30', '18', '17')
+ and c_acctbal > (
+ select avg(c_acctbal) from customer where c_acctbal > 0.00
+ and substring(c_phone from 1 for 2) in ('13', '31', '23', '29', '30', '18', '17')
+ )
+ and not exists ( select * from orders where o_custkey = c_custkey )
+ ) as custsale
+group by cntrycode
+order by cntrycode;"#;
+
+ // assert plan
+ let plan = ctx
+ .create_logical_plan(sql)
+ .map_err(|e| format!("{:?} at {}", e, "error"))
+ .unwrap();
+ let plan = ctx
+ .optimize(&plan)
+ .map_err(|e| format!("{:?} at {}", e, "error"))
+ .unwrap();
+ let actual = format!("{}", plan.display_indent());
+ let expected = r#"Sort: #custsale.cntrycode ASC NULLS LAST
+ Projection: #custsale.cntrycode, #COUNT(UInt8(1)) AS numcust, #SUM(custsale.c_acctbal) AS totacctbal
+ Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(#custsale.c_acctbal)]]
+ Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale
+ Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS cntrycode, #customer.c_acctbal, alias=custsale
+ Filter: #customer.c_acctbal > #__sq_1.__value
+ CrossJoin:
+ Anti Join: #customer.c_custkey = #orders.o_custkey
+ Filter: substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+ TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]
+ TableScan: orders projection=[o_custkey]
+ Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
+ Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
+ Filter: #customer.c_acctbal > Float64(0) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+ TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+---------+------------+",
+ "| cntrycode | numcust | totacctbal |",
+ "+-----------+---------+------------+",
+ "| 18 | 1 | 8324.07 |",
+ "| 30 | 1 | 7638.57 |",
+ "+-----------+---------+------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q11_correlated() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_tpch_csv(&ctx, "partsupp").await?;
+ register_tpch_csv(&ctx, "supplier").await?;
+ register_tpch_csv(&ctx, "nation").await?;
+
+ let sql = r#"select ps_partkey, sum(ps_supplycost * ps_availqty) as value
+from partsupp, supplier, nation
+where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name = 'GERMANY'
+group by ps_partkey having
+ sum(ps_supplycost * ps_availqty) > (
+ select sum(ps_supplycost * ps_availqty) * 0.0001
+ from partsupp, supplier, nation
+ where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name = 'GERMANY'
+ )
+order by value desc;
+"#;
+
+ // assert plan
+ let plan = ctx
+ .create_logical_plan(sql)
+ .map_err(|e| format!("{:?}", e))
+ .unwrap();
+ println!("before:\n{}", plan.display_indent());
+ let plan = ctx
+ .optimize(&plan)
+ .map_err(|e| format!("{:?}", e))
+ .unwrap();
+ let actual = format!("{}", plan.display_indent());
+ println!("after:\n{}", actual);
+ let expected = r#"Sort: #value DESC NULLS FIRST
+ Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
+ Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) > #__sq_1.__value
+ CrossJoin:
+ Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]]
+ Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+ Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_nationkey]
+ Filter: #nation.n_name = Utf8("GERMANY")
+ TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]
+ Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) AS __value, alias=__sq_1
+ Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]]
+ Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+ Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+ TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
+ TableScan: supplier projection=[s_suppkey, s_nationkey]
+ Filter: #nation.n_name = Utf8("GERMANY")
+ TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]"#
+ .to_string();
+ assert_eq!(actual, expected);
+
+ // assert data
+ let results = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["++", "++"];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
diff --git a/datafusion/core/tests/tpch-csv/lineitem.csv b/datafusion/core/tests/tpch-csv/lineitem.csv
index 47f08711d..797a89180 100644
--- a/datafusion/core/tests/tpch-csv/lineitem.csv
+++ b/datafusion/core/tests/tpch-csv/lineitem.csv
@@ -1,5 +1,5 @@
l_orderkey,l_partkey,l_suppkey,l_linenumber,l_quantity,l_extendedprice,l_discount,l_tax,l_returnflag,l_linestatus,l_shipdate,l_commitdate,l_receiptdate,l_shipinstruct,l_shipmode,l_comment
-1,67310,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,ly final dependencies: slyly bold
+1,67310,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE BACK RETURN,MAIL,ly final dependencies: slyly bold
1,63700,3701,3,8.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE BACK RETURN,REG AIR,"riously. regular, express dep"
1,2132,4633,4,28.0,28955.64,0.09,0.06,N,O,1996-04-21,1996-03-30,1996-05-16,NONE,AIR,lites. fluffily even de
1,24027,1534,5,24.0,22824.48,0.1,0.04,N,O,1996-03-30,1996-03-14,1996-04-01,NONE,FOB, pending foxes. slyly re
@@ -7,4 +7,4 @@ l_orderkey,l_partkey,l_suppkey,l_linenumber,l_quantity,l_extendedprice,l_discoun
2,106170,1191,1,38.0,44694.46,0.0,0.05,N,O,1997-01-28,1997-01-14,1997-02-02,TAKE BACK RETURN,RAIL,ven requests. deposits breach a
3,4297,1798,1,45.0,54058.05,0.06,0.0,R,F,1994-02-02,1994-01-04,1994-02-23,NONE,AIR,ongside of the furiously brave acco
3,19036,6540,2,49.0,46796.47,0.1,0.0,R,F,1993-11-09,1993-12-20,1993-11-24,TAKE BACK RETURN,RAIL, unusual accounts. eve
-3,128449,3474,3,27.0,39890.88,0.06,0.07,A,F,1994-01-16,1993-11-22,1994-01-23,DELIVER IN PERSON,SHIP,nal foxes wake.
+3,128449,3474,3,27.0,39890.88,0.06,0.07,A,F,1994-01-16,1993-11-22,1994-01-23,DELIVER IN PERSON,SHIP,nal foxes wake.
diff --git a/datafusion/core/tests/tpch-csv/nation.csv b/datafusion/core/tests/tpch-csv/nation.csv
index e37130f4a..4b3010596 100644
--- a/datafusion/core/tests/tpch-csv/nation.csv
+++ b/datafusion/core/tests/tpch-csv/nation.csv
@@ -8,4 +8,4 @@ n_nationkey,n_name,n_regionkey,n_comment
7,GERMANY,3,"l platelets. regular accounts x-ray: unusual, regular acco"
8,INDIA,2,ss excuses cajole slyly across the packages. deposits print aroun
9,INDONESIA,2, slyly express asymptotes. regular deposits haggle slyly. carefully ironic hockey players sleep blithely. carefull
-10,IRAN,4,efully alongside of the slyly final dependencies.
+10,IRAN,4,efully alongside of the slyly final dependencies.
diff --git a/datafusion/core/tests/tpch-csv/part.csv b/datafusion/core/tests/tpch-csv/part.csv
new file mode 100644
index 000000000..b505100ff
--- /dev/null
+++ b/datafusion/core/tests/tpch-csv/part.csv
@@ -0,0 +1,2 @@
+p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment
+63700,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly ironi
diff --git a/datafusion/core/tests/tpch-csv/partsupp.csv b/datafusion/core/tests/tpch-csv/partsupp.csv
new file mode 100644
index 000000000..d7db83d03
--- /dev/null
+++ b/datafusion/core/tests/tpch-csv/partsupp.csv
@@ -0,0 +1,2 @@
+ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment
+67310,7311,100,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff
diff --git a/datafusion/core/tests/tpch-csv/region.csv b/datafusion/core/tests/tpch-csv/region.csv
new file mode 100644
index 000000000..269c09156
--- /dev/null
+++ b/datafusion/core/tests/tpch-csv/region.csv
@@ -0,0 +1,2 @@
+r_regionkey,r_name,r_comment
+4,MIDDLE EAST,uickly special accounts cajole carefully blithely close requests. carefully final asymptotes haggle furiousl
diff --git a/datafusion/core/tests/tpch-csv/supplier.csv b/datafusion/core/tests/tpch-csv/supplier.csv
new file mode 100644
index 000000000..85f9aaefb
--- /dev/null
+++ b/datafusion/core/tests/tpch-csv/supplier.csv
@@ -0,0 +1,3 @@
+s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment
+1,Supplier#000000001," N kD4on9OM Ipw3,gf0JBoQDd7tgrzrddZ",17,27-918-335-1736,5755.94,each slyly above the careful
+8136,Supplier#000008136,kXATyaEZOWdQC7fE43IquuR1HkKV8qx,20,30-268-895-2611,8383.6,er the carefully regular depths. pinto beans detect quickly p
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index ad0b58fac..ba6f7a96c 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -27,7 +27,7 @@ use crate::AggregateUDF;
use crate::Operator;
use crate::ScalarUDF;
use arrow::datatypes::DataType;
-use datafusion_common::Column;
+use datafusion_common::{plan_err, Column};
use datafusion_common::{DFSchema, Result};
use datafusion_common::{DataFusionError, ScalarValue};
use std::fmt;
@@ -452,6 +452,13 @@ impl Expr {
nulls_first,
}
}
+
+ pub fn try_into_col(&self) -> Result<Column> {
+ match self {
+ Expr::Column(it) => Ok(it.clone()),
+ _ => plan_err!(format!("Could not coerce '{}' into Column!", self)),
+ }
+ }
}
impl Not for Expr {
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 93c18f4b9..d42109788 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -20,7 +20,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::utils::exprlist_to_fields;
use crate::{Expr, TableProviderFilterPushDown, TableSource};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError};
+use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError};
use std::collections::HashSet;
///! Logical plan types
use std::fmt::{self, Debug, Display, Formatter};
@@ -1074,6 +1074,13 @@ impl Projection {
alias,
})
}
+
+ pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Projection> {
+ match plan {
+ LogicalPlan::Projection(it) => Ok(it),
+ _ => plan_err!("Could not coerce into Projection!"),
+ }
+ }
}
/// Aliased subquery
@@ -1103,6 +1110,15 @@ pub struct Filter {
pub input: Arc<LogicalPlan>,
}
+impl Filter {
+ pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Filter> {
+ match plan {
+ LogicalPlan::Filter(it) => Ok(it),
+ _ => plan_err!("Could not coerce into Filter!"),
+ }
+ }
+}
+
/// Window its input based on a set of window spec and window function (e.g. SUM or RANK)
#[derive(Clone)]
pub struct Window {
@@ -1287,6 +1303,15 @@ pub struct Aggregate {
pub schema: DFSchemaRef,
}
+impl Aggregate {
+ pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Aggregate> {
+ match plan {
+ LogicalPlan::Aggregate(it) => Ok(it),
+ _ => plan_err!("Could not coerce into Aggregate!"),
+ }
+ }
+}
+
/// Sorts its input according to a list of sort expressions.
#[derive(Clone)]
pub struct Sort {
@@ -1324,6 +1349,15 @@ pub struct Subquery {
pub subquery: Arc<LogicalPlan>,
}
+impl Subquery {
+ pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
+ match plan {
+ Expr::ScalarSubquery(it) => Ok(it),
+ _ => plan_err!("Could not coerce into ScalarSubquery!"),
+ }
+ }
+}
+
impl Debug for Subquery {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "<subquery>")
diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml
index ae493b2b0..24d2f1812 100644
--- a/datafusion/optimizer/Cargo.toml
+++ b/datafusion/optimizer/Cargo.toml
@@ -45,3 +45,8 @@ datafusion-expr = { path = "../expr", version = "10.0.0" }
datafusion-physical-expr = { path = "../physical-expr", version = "10.0.0" }
hashbrown = { version = "0.12", features = ["raw"] }
log = "^0.4"
+
+[dev-dependencies]
+ctor = "0.1.22"
+env_logger = "0.9.0"
+
diff --git a/datafusion/optimizer/src/decorrelate_scalar_subquery.rs b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs
new file mode 100644
index 000000000..d4f8372bd
--- /dev/null
+++ b/datafusion/optimizer/src/decorrelate_scalar_subquery.rs
@@ -0,0 +1,705 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::utils::{
+ exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction,
+ verify_not_disjunction,
+};
+use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_common::{context, plan_err, Column, Result};
+use datafusion_expr::logical_plan::{Aggregate, Filter, JoinType, Projection, Subquery};
+use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator};
+use log::debug;
+use std::sync::Arc;
+
+/// Optimizer rule for rewriting scalar subquery filters into joins
+#[derive(Default)]
+pub struct DecorrelateScalarSubquery {}
+
+impl DecorrelateScalarSubquery {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ /// Finds expressions that have a scalar subquery in them (and recurses when found)
+ ///
+ /// # Arguments
+ /// * `predicate` - A conjunction to split and search
+ /// * `optimizer_config` - For generating unique subquery aliases
+ ///
+ /// Returns a tuple (subqueries, non-subquery expressions)
+ fn extract_subquery_exprs(
+ &self,
+ predicate: &Expr,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> Result<(Vec<SubqueryInfo>, Vec<Expr>)> {
+ let mut filters = vec![];
+ split_conjunction(predicate, &mut filters); // TODO: disjunctions
+
+ let mut subqueries = vec![];
+ let mut others = vec![];
+ for it in filters.iter() {
+ match it {
+ Expr::BinaryExpr { left, op, right } => {
+ let l_query = Subquery::try_from_expr(left);
+ let r_query = Subquery::try_from_expr(right);
+ if l_query.is_err() && r_query.is_err() {
+ others.push((*it).clone());
+ continue;
+ }
+ let mut recurse =
+ |q: Result<&Subquery>, expr: Expr, lhs: bool| -> Result<()> {
+ let subquery = match q {
+ Ok(subquery) => subquery,
+ _ => return Ok(()),
+ };
+ let subquery =
+ self.optimize(&*subquery.subquery, optimizer_config)?;
+ let subquery = Arc::new(subquery);
+ let subquery = Subquery { subquery };
+ let res = SubqueryInfo::new(subquery, expr, *op, lhs);
+ subqueries.push(res);
+ Ok(())
+ };
+ recurse(l_query, (**right).clone(), false)?;
+ recurse(r_query, (**left).clone(), true)?;
+ // TODO: if subquery doesn't get optimized, optimized children are lost
+ }
+ _ => others.push((*it).clone()),
+ }
+ }
+
+ Ok((subqueries, others))
+ }
+}
+
+impl OptimizerRule for DecorrelateScalarSubquery {
+ fn optimize(
+ &self,
+ plan: &LogicalPlan,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ match plan {
+ LogicalPlan::Filter(Filter { predicate, input }) => {
+ // Apply optimizer rule to current input
+ let optimized_input = self.optimize(input, optimizer_config)?;
+
+ let (subqueries, other_exprs) =
+ self.extract_subquery_exprs(predicate, optimizer_config)?;
+ let optimized_plan = LogicalPlan::Filter(Filter {
+ predicate: predicate.clone(),
+ input: Arc::new(optimized_input),
+ });
+ if subqueries.is_empty() {
+ // regular filter, no scalar subquery here
+ return Ok(optimized_plan);
+ }
+
+ // iterate through all scalar subqueries in the predicate, turning each into a join
+ let mut cur_input = (**input).clone();
+ for subquery in subqueries {
+ cur_input = optimize_scalar(
+ &subquery,
+ &cur_input,
+ &other_exprs,
+ optimizer_config,
+ )?;
+ }
+ Ok(cur_input)
+ }
+ _ => {
+ // Apply the optimization to all inputs of the plan
+ utils::optimize_children(self, plan, optimizer_config)
+ }
+ }
+ }
+
+ fn name(&self) -> &str {
+ "decorrelate_scalar_subquery"
+ }
+}
+
+/// Takes a query like:
+///
+/// ```sql
+/// select id from customers where balance >
+///     (select avg(total) from orders where orders.c_id = customers.id)
+/// ```
+///
+/// and optimizes it into:
+///
+/// ```sql
+/// select c.id from customers c
+/// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.id
+/// where c.balance > o.val
+/// ```
+///
+/// # Arguments
+///
+/// * `query_info` - The scalar subquery and the expression/operator it is compared against
+/// * `filter_input` - The non-subquery portion of the filter's input (from customers)
+/// * `outer_others` - Any additional parts of the outer `where` expression (and c.x = y)
+/// * `optimizer_config` - Used to generate unique subquery aliases
+fn optimize_scalar(
+ query_info: &SubqueryInfo,
+ filter_input: &LogicalPlan,
+ outer_others: &[Expr],
+ optimizer_config: &mut OptimizerConfig,
+) -> Result<LogicalPlan> {
+ debug!(
+ "optimizing:\n{}",
+ query_info.query.subquery.display_indent()
+ );
+ let proj = Projection::try_from_plan(&*query_info.query.subquery)
+ .map_err(|e| context!("scalar subqueries must have a projection", e))?;
+ let proj = only_or_err(proj.expr.as_slice())
+ .map_err(|e| context!("exactly one expression should be projected", e))?;
+ let proj = Expr::Alias(Box::new(proj.clone()), "__value".to_string());
+ let sub_inputs = query_info.query.subquery.inputs();
+ let sub_input = only_or_err(sub_inputs.as_slice())
+ .map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?;
+ let aggr = Aggregate::try_from_plan(sub_input)
+ .map_err(|e| context!("scalar subqueries must aggregate a value", e))?;
+ let filter = Filter::try_from_plan(&*aggr.input).map_err(|e| {
+ context!("scalar subqueries must have a filter to be correlated", e)
+ })?;
+
+ // split into filters
+ let mut subqry_filter_exprs = vec![];
+ split_conjunction(&filter.predicate, &mut subqry_filter_exprs);
+ verify_not_disjunction(&subqry_filter_exprs)?;
+
+ // Grab column names to join on
+ let (col_exprs, other_subqry_exprs) =
+ find_join_exprs(subqry_filter_exprs, filter.input.schema())?;
+ let (outer_cols, subqry_cols, join_filters) =
+ exprs_to_join_cols(&col_exprs, filter.input.schema(), false)?;
+ if join_filters.is_some() {
+ plan_err!("only joins on column equality are presently supported")?;
+ }
+
+ // Only operate if one column is present and the other is closed over from the outer scope
+ let subqry_alias = format!("__sq_{}", optimizer_config.next_id());
+ let group_by: Vec<_> = subqry_cols
+ .iter()
+ .map(|it| Expr::Column(it.clone()))
+ .collect();
+
+ // build the subquery side of the join from the plan the subquery was querying
+ let mut subqry_plan = LogicalPlanBuilder::from((*filter.input).clone());
+ if let Some(expr) = combine_filters(&other_subqry_exprs) {
+ subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them
+ }
+
+ // project the prior projection + any correlated (and now grouped) columns
+ let proj: Vec<_> = group_by
+ .iter()
+ .cloned()
+ .chain(vec![proj].iter().cloned())
+ .collect();
+ let subqry_plan = subqry_plan
+ .aggregate(group_by, aggr.aggr_expr.clone())?
+ .project_with_alias(proj, Some(subqry_alias.clone()))?
+ .build()?;
+
+ // qualify the join columns for outside the subquery
+ let subqry_cols: Vec<_> = subqry_cols
+ .iter()
+ .map(|it| Column {
+ relation: Some(subqry_alias.clone()),
+ name: it.name.clone(),
+ })
+ .collect();
+ let join_keys = (outer_cols, subqry_cols);
+
+ // join our sub query into the main plan
+ let new_plan = LogicalPlanBuilder::from(filter_input.clone());
+ let mut new_plan = if join_keys.0.is_empty() {
+ // if not correlated, group down to 1 row and cross join on that (preserving row count)
+ new_plan.cross_join(&subqry_plan)?
+ } else {
+ // inner join if correlated, grouping by the join keys so we don't change row count
+ new_plan.join(&subqry_plan, JoinType::Inner, join_keys, None)?
+ };
+
+ // restore the where condition comparing against the subquery value
+ let qry_expr = Box::new(Expr::Column(Column {
+ relation: Some(subqry_alias),
+ name: "__value".to_string(),
+ }));
+ let filter_expr = if query_info.expr_on_left {
+ Expr::BinaryExpr {
+ left: Box::new(query_info.expr.clone()),
+ op: query_info.op,
+ right: qry_expr,
+ }
+ } else {
+ Expr::BinaryExpr {
+ left: qry_expr,
+ op: query_info.op,
+ right: Box::new(query_info.expr.clone()),
+ }
+ };
+ new_plan = new_plan.filter(filter_expr)?;
+
+ // if the main query had additional expressions, restore them
+ if let Some(expr) = combine_filters(outer_others) {
+ new_plan = new_plan.filter(expr)?
+ }
+ let new_plan = new_plan.build()?;
+
+ Ok(new_plan)
+}
+
+struct SubqueryInfo {
+ query: Subquery,
+ expr: Expr,
+ op: Operator,
+ expr_on_left: bool,
+}
+
+impl SubqueryInfo {
+ pub fn new(query: Subquery, expr: Expr, op: Operator, expr_on_left: bool) -> Self {
+ Self {
+ query,
+ expr,
+ op,
+ expr_on_left,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::*;
+ use datafusion_common::Result;
+ use datafusion_expr::{
+ col, lit, logical_plan::LogicalPlanBuilder, max, min, scalar_subquery, sum,
+ };
+ use std::ops::Add;
+
+ #[cfg(test)]
+ #[ctor::ctor]
+ fn init() {
+ let _ = env_logger::try_init();
+ }
+
+ /// Test multiple correlated subqueries
+ #[test]
+ fn multiple_subqueries() -> Result<()> {
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("customer.c_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ lit(1)
+ .lt(scalar_subquery(orders.clone()))
+ .and(lit(1).lt(scalar_subquery(orders))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: Int32(1) < #__sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N, o_custkey:Int64, __value:Int64;N]
+ Inner Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N, o_custkey:Int64, __value:Int64;N]
+ Filter: Int32(1) < #__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ Inner Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+ Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_2 [o_custkey:Int64, __value:Int64;N]
+ Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test recursive correlated subqueries
+ #[test]
+ fn recursive_subqueries() -> Result<()> {
+ let lineitem = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
+ .filter(col("lineitem.l_orderkey").eq(col("orders.o_orderkey")))?
+ .aggregate(
+ Vec::<Expr>::new(),
+ vec![sum(col("lineitem.l_extendedprice"))],
+ )?
+ .project(vec![sum(col("lineitem.l_extendedprice"))])?
+ .build()?,
+ );
+
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("orders.o_custkey")
+ .eq(col("customer.c_custkey"))
+ .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
+ )?
+ .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
+ .project(vec![sum(col("orders.o_totalprice"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_acctbal < #__sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]
+ Inner Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, alias=__sq_2 [o_custkey:Int64, __value:Float64;N]
+ Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[SUM(#orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N]
+ Filter: #orders.o_totalprice < #__sq_1.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]
+ Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ Projection: #lineitem.l_orderkey, #SUM(lineitem.l_extendedprice) AS __value, alias=__sq_1 [l_orderkey:Int64, __value:Float64;N]
+ Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[SUM(#lineitem.l_extendedprice)]] [l_orderkey:Int64, SUM(lineitem.l_extendedprice):Float64;N]
+ TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery filter with additional subquery filters
+ #[test]
+ fn scalar_subquery_with_subquery_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .and(col("o_orderkey").eq(lit(1))),
+ )?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ Inner Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+ Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+ Filter: #orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery with no columns in schema
+ #[test]
+ fn scalar_subquery_no_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // it will optimize, but fail for the same reason the unoptimized query would
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N]
+ Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]
+ Filter: #customer.c_custkey = #customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for scalar subquery with both columns in schema
+ #[test]
+ fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #MAX(orders.o_custkey) AS __value, alias=__sq_1 [__value:Int64;N]
+ Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]
+ Filter: #orders.o_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery not equal
+ #[test]
+ fn scalar_subquery_where_not_eq() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").not_eq(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"only joins on column equality are presently supported"#;
+
+ assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery less than
+ #[test]
+ fn scalar_subquery_where_less_than() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").lt(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"can't optimize < column comparison"#;
+ assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery filter with subquery disjunction
+ #[test]
+ fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .or(col("o_orderkey").eq(lit(1))),
+ )?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Optimizing disjunctions not supported!"#;
+ assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery without projection
+ #[test]
+ fn scalar_subquery_no_projection() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"scalar subqueries must have a projection"#;
+ assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar expressions
+ #[test]
+ #[ignore]
+ fn scalar_subquery_project_expr() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey")).add(lit(1))])?
+ .build()?,
+ );
+ /*
+ Error: SchemaError(FieldNotFound { qualifier: Some("orders"), name: "o_custkey", valid_fields: Some(["MAX(orders.o_custkey)"]) })
+ */
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#""#;
+
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery multiple projected columns
+ #[test]
+ fn scalar_subquery_multi_col() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(scalar_subquery(sq))
+ .and(col("c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"exactly one expression should be projected"#;
+ assert_optimizer_err(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery filter with additional filters
+ #[test]
+ fn scalar_subquery_additional_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(scalar_subquery(sq))
+ .and(col("c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ Filter: #customer.c_custkey = #__sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ Inner Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey, #MAX(orders.o_custkey) AS __value, alias=__sq_1 [o_custkey:Int64, __value:Int64;N]
+ Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[MAX(#orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery filter with disjunctions
+ #[test]
+ fn scalar_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
+ .project(vec![max(col("orders.o_custkey"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(scalar_subquery(sq))
+ .or(col("customer.c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // unoptimized plan because we don't support disjunctions yet
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = (<subquery>) OR #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
+ Subquery: [MAX(orders.o_custkey):Int64;N]
+ Projection: #MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]
+ Aggregate: groupBy=[[]], aggr=[[MAX(#orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]
+ Filter: #customer.c_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated scalar subquery filter
+ #[test]
+ fn scalar_subquery_correlated() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
+ .filter(col("test.a").eq(col("sq.a")))?
+ .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
+ .project(vec![min(col("c"))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
+ .filter(col("test.c").lt(scalar_subquery(sq)))?
+ .project(vec![col("test.c")])?
+ .build()?;
+
+ let expected = r#"Projection: #test.c [c:UInt32]
+ Filter: #test.c < #__sq_1.__value [a:UInt32, b:UInt32, c:UInt32, a:UInt32, __value:UInt32;N]
+ Inner Join: #test.a = #__sq_1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, __value:UInt32;N]
+ TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+ Projection: #sq.a, #MIN(sq.c) AS __value, alias=__sq_1 [a:UInt32, __value:UInt32;N]
+ Aggregate: groupBy=[[#sq.a]], aggr=[[MIN(#sq.c)]] [a:UInt32, MIN(sq.c):UInt32;N]
+ TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+
+ assert_optimized_plan_eq(&DecorrelateScalarSubquery::new(), &plan, expected);
+ Ok(())
+ }
+}
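[Editorial note, not part of the patch: the scalar-subquery rewrite above replaces a per-row correlated `MAX` with one grouped aggregate joined back to the outer table (the `Aggregate: groupBy=[[#orders.o_custkey]]` + `Inner Join` shape in the expected plans). A minimal standalone sketch of that equivalence, using hypothetical in-memory data rather than DataFusion plans:]

```rust
use std::collections::HashMap;

// Correlated form: for each outer key, rescan all orders (O(n * m)).
fn max_per_customer_correlated(custkeys: &[i64], orders: &[(i64, i64)]) -> Vec<Option<i64>> {
    custkeys
        .iter()
        .map(|c| {
            orders
                .iter()
                .filter(|(o_custkey, _)| o_custkey == c)
                .map(|(_, o_orderkey)| *o_orderkey)
                .max()
        })
        .collect()
}

// Decorrelated form: one grouped aggregate, then a hash lookup per outer row.
fn max_per_customer_decorrelated(custkeys: &[i64], orders: &[(i64, i64)]) -> Vec<Option<i64>> {
    let mut maxes: HashMap<i64, i64> = HashMap::new();
    for (o_custkey, o_orderkey) in orders {
        let e = maxes.entry(*o_custkey).or_insert(*o_orderkey);
        if *o_orderkey > *e {
            *e = *o_orderkey;
        }
    }
    custkeys.iter().map(|c| maxes.get(c).copied()).collect()
}

fn main() {
    let custkeys = [1, 2, 3];
    let orders = [(1, 10), (1, 20), (2, 5)];
    // Both strategies agree; the decorrelated one scans `orders` only once.
    assert_eq!(
        max_per_customer_correlated(&custkeys, &orders),
        max_per_customer_decorrelated(&custkeys, &orders)
    );
}
```

[Keys with no matching orders yield `None`, which is why the join in the real rewrite must not drop the filter comparing against `__value`.]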
diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs
new file mode 100644
index 000000000..2c25bcbb2
--- /dev/null
+++ b/datafusion/optimizer/src/decorrelate_where_exists.rs
@@ -0,0 +1,557 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::utils::{
+ exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction,
+ verify_not_disjunction,
+};
+use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_common::{context, plan_err};
+use datafusion_expr::logical_plan::{Filter, JoinType, Subquery};
+use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder};
+use std::sync::Arc;
+
+/// Optimizer rule for rewriting subquery filters to joins
+#[derive(Default)]
+pub struct DecorrelateWhereExists {}
+
+impl DecorrelateWhereExists {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ /// Finds expressions that contain a `WHERE EXISTS` subquery (and recurses when found)
+ ///
+ /// # Arguments
+ ///
+ /// * `predicate` - A conjunction to split and search
+ /// * `optimizer_config` - For generating unique subquery aliases
+ ///
+ /// Returns a tuple (subqueries, non-subquery expressions)
+ fn extract_subquery_exprs(
+ &self,
+ predicate: &Expr,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> datafusion_common::Result<(Vec<SubqueryInfo>, Vec<Expr>)> {
+ let mut filters = vec![];
+ split_conjunction(predicate, &mut filters);
+
+ let mut subqueries = vec![];
+ let mut others = vec![];
+ for it in filters.iter() {
+ match it {
+ Expr::Exists { subquery, negated } => {
+ let subquery =
+ self.optimize(&*subquery.subquery, optimizer_config)?;
+ let subquery = Arc::new(subquery);
+ let subquery = Subquery { subquery };
+ let subquery = SubqueryInfo::new(subquery.clone(), *negated);
+ subqueries.push(subquery);
+ }
+ _ => others.push((*it).clone()),
+ }
+ }
+
+ Ok((subqueries, others))
+ }
+}
+
+impl OptimizerRule for DecorrelateWhereExists {
+ fn optimize(
+ &self,
+ plan: &LogicalPlan,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> datafusion_common::Result<LogicalPlan> {
+ match plan {
+ LogicalPlan::Filter(Filter {
+ predicate,
+ input: filter_input,
+ }) => {
+ // Apply optimizer rule to current input
+ let optimized_input = self.optimize(filter_input, optimizer_config)?;
+
+ let (subqueries, other_exprs) =
+ self.extract_subquery_exprs(predicate, optimizer_config)?;
+ let optimized_plan = LogicalPlan::Filter(Filter {
+ predicate: predicate.clone(),
+ input: Arc::new(optimized_input),
+ });
+ if subqueries.is_empty() {
+ // regular filter, no EXISTS subquery clause here
+ return Ok(optimized_plan);
+ }
+
+ // iterate through all exists clauses in predicate, turning each into a join
+ let mut cur_input = (**filter_input).clone();
+ for subquery in subqueries {
+ cur_input = optimize_exists(&subquery, &cur_input, &other_exprs)?;
+ }
+ Ok(cur_input)
+ }
+ _ => {
+ // Apply the optimization to all inputs of the plan
+ utils::optimize_children(self, plan, optimizer_config)
+ }
+ }
+ }
+
+ fn name(&self) -> &str {
+ "decorrelate_where_exists"
+ }
+}
+
+/// Takes a query like:
+///
+/// ```select c.id from customers c where exists (select * from orders o where o.c_id = c.id)```
+///
+/// and optimizes it into:
+///
+/// ```select c.id from customers c
+/// inner join (select o.c_id from orders o group by o.c_id) o on o.c_id = c.id```
+///
+/// # Arguments
+///
+/// * `query_info` - The subquery portion of the `where exists` (select * from orders),
+///   plus whether it was negated (`where not exists`)
+/// * `outer_input` - The non-subquery portion (from customers)
+/// * `outer_other_exprs` - Any additional parts to the `where` expression (and c.x = y)
+fn optimize_exists(
+ query_info: &SubqueryInfo,
+ outer_input: &LogicalPlan,
+ outer_other_exprs: &[Expr],
+) -> datafusion_common::Result<LogicalPlan> {
+ let subqry_inputs = query_info.query.subquery.inputs();
+ let subqry_input = only_or_err(subqry_inputs.as_slice())
+ .map_err(|e| context!("single expression projection required", e))?;
+ let subqry_filter = Filter::try_from_plan(subqry_input)
+ .map_err(|e| context!("cannot optimize non-correlated subquery", e))?;
+
+ // split into filters
+ let mut subqry_filter_exprs = vec![];
+ split_conjunction(&subqry_filter.predicate, &mut subqry_filter_exprs);
+ verify_not_disjunction(&subqry_filter_exprs)?;
+
+ // Grab column names to join on
+ let (col_exprs, other_subqry_exprs) =
+ find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?;
+ let (outer_cols, subqry_cols, join_filters) =
+ exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?;
+ if subqry_cols.is_empty() || outer_cols.is_empty() {
+ plan_err!("cannot optimize non-correlated subquery")?;
+ }
+
+ // build subquery side of join - the thing the subquery was querying
+ let mut subqry_plan = LogicalPlanBuilder::from((*subqry_filter.input).clone());
+ if let Some(expr) = combine_filters(&other_subqry_exprs) {
+ subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them
+ }
+ let subqry_plan = subqry_plan.build()?;
+
+ let join_keys = (subqry_cols, outer_cols);
+
+ // join our sub query into the main plan
+ let join_type = match query_info.negated {
+ true => JoinType::Anti,
+ false => JoinType::Semi,
+ };
+ let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join(
+ &subqry_plan,
+ join_type,
+ join_keys,
+ join_filters,
+ )?;
+ if let Some(expr) = combine_filters(outer_other_exprs) {
+ new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them
+ }
+
+ let result = new_plan.build()?;
+ Ok(result)
+}
+
+struct SubqueryInfo {
+ query: Subquery,
+ negated: bool,
+}
+
+impl SubqueryInfo {
+ pub fn new(query: Subquery, negated: bool) -> Self {
+ Self { query, negated }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::*;
+ use datafusion_common::Result;
+ use datafusion_expr::{
+ col, exists, lit, logical_plan::LogicalPlanBuilder, not_exists,
+ };
+ use std::ops::Add;
+
+ /// Test for multiple exists subqueries in the same filter expression
+ #[test]
+ fn multiple_subqueries() -> Result<()> {
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("customer.c_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(orders.clone()).and(exists(orders)))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test recursive correlated subqueries
+ #[test]
+ fn recursive_subqueries() -> Result<()> {
+ let lineitem = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
+ .filter(col("lineitem.l_orderkey").eq(col("orders.o_orderkey")))?
+ .project(vec![col("lineitem.l_orderkey")])?
+ .build()?,
+ );
+
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ exists(lineitem)
+ .and(col("orders.o_custkey").eq(col("customer.c_custkey"))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(orders))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Semi Join: #orders.o_orderkey = #lineitem.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery filter with additional subquery filters
+ #[test]
+ fn exists_subquery_with_subquery_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .and(col("o_orderkey").eq(lit(1))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Filter: #orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery with no columns in schema
+ #[test]
+ fn exists_subquery_no_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"cannot optimize non-correlated subquery"#;
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for exists subquery with both columns in schema
+ #[test]
+ fn exists_subquery_with_no_correlated_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"cannot optimize non-correlated subquery"#;
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery not equal
+ #[test]
+ fn exists_subquery_where_not_eq() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").not_eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"cannot optimize non-correlated subquery"#;
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery less than
+ #[test]
+ fn exists_subquery_where_less_than() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").lt(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"can't optimize < column comparison"#;
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery filter with subquery disjunction
+ #[test]
+ fn exists_subquery_with_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .or(col("o_orderkey").eq(lit(1))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Optimizing disjunctions not supported!"#;
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists without projection
+ #[test]
+ fn exists_subquery_no_projection() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"cannot optimize non-correlated subquery"#;
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists expressions
+ #[test]
+ fn exists_subquery_project_expr() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey").add(lit(1))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // It doesn't matter that we projected an expression, only that a result was returned
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery filter with additional filters
+ #[test]
+ fn should_support_additional_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq).and(col("c_custkey").eq(lit(1))))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
+ Semi Join: #customer.c_custkey = #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated exists subquery filter with disjunctions
+ #[test]
+ fn exists_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(exists(sq).or(col("customer.c_custkey").eq(lit(1))))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // not optimized
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: EXISTS (<subquery>) OR #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
+ Subquery: [o_custkey:Int64]
+ Projection: #orders.o_custkey [o_custkey:Int64]
+ Filter: #customer.c_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated EXISTS subquery filter
+ #[test]
+ fn exists_subquery_correlated() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
+ .filter(col("test.a").eq(col("sq.a")))?
+ .project(vec![col("c")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
+ .filter(exists(sq))?
+ .project(vec![col("test.c")])?
+ .build()?;
+
+ let expected = r#"Projection: #test.c [c:UInt32]
+ Semi Join: #test.a = #sq.a [a:UInt32, b:UInt32, c:UInt32]
+ TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+ TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for single exists subquery filter
+ #[test]
+ fn exists_subquery_simple() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(exists(test_subquery_with_name("sq")?))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "cannot optimize non-correlated subquery";
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for single NOT exists subquery filter
+ #[test]
+ fn not_exists_subquery_simple() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(not_exists(test_subquery_with_name("sq")?))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = "cannot optimize non-correlated subquery";
+
+ assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
+ Ok(())
+ }
+}
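[Editorial note, not part of the patch: the `DecorrelateWhereExists` rule above maps `EXISTS` to `JoinType::Semi` and `NOT EXISTS` to `JoinType::Anti`. A standalone sketch of why those join types have exactly the right semantics, over hypothetical key vectors rather than logical plans:]

```rust
use std::collections::HashSet;

// EXISTS == semi join: keep each outer row that has at least one match,
// without duplicating it when the inner side matches more than once.
fn semi_join(outer: &[i64], inner: &[i64]) -> Vec<i64> {
    let keys: HashSet<i64> = inner.iter().copied().collect();
    outer.iter().copied().filter(|k| keys.contains(k)).collect()
}

// NOT EXISTS == anti join: keep each outer row with no match at all.
fn anti_join(outer: &[i64], inner: &[i64]) -> Vec<i64> {
    let keys: HashSet<i64> = inner.iter().copied().collect();
    outer.iter().copied().filter(|k| !keys.contains(k)).collect()
}

fn main() {
    let customers = [1, 2, 3];
    let order_custkeys = [1, 1, 3]; // customer 1 matches twice, 3 once, 2 never
    assert_eq!(semi_join(&customers, &order_custkeys), vec![1, 3]);
    assert_eq!(anti_join(&customers, &order_custkeys), vec![2]);
}
```

[Note the semi join emits customer 1 once despite two matching orders; a plain inner join would duplicate it, which is why the rule cannot use `JoinType::Inner` here.]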
diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs
new file mode 100644
index 000000000..f90d94d8c
--- /dev/null
+++ b/datafusion/optimizer/src/decorrelate_where_in.rs
@@ -0,0 +1,693 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::utils::{
+ alias_cols, exprs_to_join_cols, find_join_exprs, merge_cols, only_or_err,
+ split_conjunction, swap_table, verify_not_disjunction,
+};
+use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_common::context;
+use datafusion_expr::logical_plan::{Filter, JoinType, Projection, Subquery};
+use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder};
+use log::debug;
+use std::sync::Arc;
+
+/// Optimizer rule for rewriting `WHERE <expr> IN (subquery)` filters to joins
+#[derive(Default)]
+pub struct DecorrelateWhereIn {}
+
+impl DecorrelateWhereIn {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+
+ /// Finds expressions that contain a `WHERE IN` subquery (and recurses when found)
+ ///
+ /// # Arguments
+ ///
+ /// * `predicate` - A conjunction to split and search
+ /// * `optimizer_config` - For generating unique subquery aliases
+ ///
+ /// Returns a tuple (subqueries, non-subquery expressions)
+ fn extract_subquery_exprs(
+ &self,
+ predicate: &Expr,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> datafusion_common::Result<(Vec<SubqueryInfo>, Vec<Expr>)> {
+ let mut filters = vec![];
+ split_conjunction(predicate, &mut filters); // TODO: disjunctions
+
+ let mut subqueries = vec![];
+ let mut others = vec![];
+ for it in filters.iter() {
+ match it {
+ Expr::InSubquery {
+ expr,
+ subquery,
+ negated,
+ } => {
+ let subquery =
+ self.optimize(&*subquery.subquery, optimizer_config)?;
+ let subquery = Arc::new(subquery);
+ let subquery = Subquery { subquery };
+ let subquery =
+ SubqueryInfo::new(subquery.clone(), (**expr).clone(), *negated);
+ subqueries.push(subquery);
+ // TODO: if subquery doesn't get optimized, optimized children are lost
+ }
+ _ => others.push((*it).clone()),
+ }
+ }
+
+ Ok((subqueries, others))
+ }
+}
+
+impl OptimizerRule for DecorrelateWhereIn {
+ fn optimize(
+ &self,
+ plan: &LogicalPlan,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> datafusion_common::Result<LogicalPlan> {
+ match plan {
+ LogicalPlan::Filter(Filter {
+ predicate,
+ input: filter_input,
+ }) => {
+ // Apply optimizer rule to current input
+ let optimized_input = self.optimize(filter_input, optimizer_config)?;
+
+ let (subqueries, other_exprs) =
+ self.extract_subquery_exprs(predicate, optimizer_config)?;
+ let optimized_plan = LogicalPlan::Filter(Filter {
+ predicate: predicate.clone(),
+ input: Arc::new(optimized_input),
+ });
+ if subqueries.is_empty() {
+ // regular filter, no IN-subquery clause here
+ return Ok(optimized_plan);
+ }
+
+ // iterate through all IN-subquery clauses in the predicate, turning each into a join
+ let mut cur_input = (**filter_input).clone();
+ for subquery in subqueries {
+ cur_input = optimize_where_in(
+ &subquery,
+ &cur_input,
+ &other_exprs,
+ optimizer_config,
+ )?;
+ }
+ Ok(cur_input)
+ }
+ _ => {
+ // Apply the optimization to all inputs of the plan
+ utils::optimize_children(self, plan, optimizer_config)
+ }
+ }
+ }
+
+ fn name(&self) -> &str {
+ "decorrelate_where_in"
+ }
+}
+
+fn optimize_where_in(
+ query_info: &SubqueryInfo,
+ outer_input: &LogicalPlan,
+ outer_other_exprs: &[Expr],
+ optimizer_config: &mut OptimizerConfig,
+) -> datafusion_common::Result<LogicalPlan> {
+ let proj = Projection::try_from_plan(&*query_info.query.subquery)
+ .map_err(|e| context!("a projection is required", e))?;
+ let mut subqry_input = proj.input.clone();
+ let proj = only_or_err(proj.expr.as_slice())
+ .map_err(|e| context!("single expression projection required", e))?;
+ let subquery_col = proj
+ .try_into_col()
+ .map_err(|e| context!("single column projection required", e))?;
+ let outer_col = query_info
+ .where_in_expr
+ .try_into_col()
+ .map_err(|e| context!("column comparison required", e))?;
+
+ // If subquery is correlated, grab necessary information
+ let mut subqry_cols = vec![];
+ let mut outer_cols = vec![];
+ let mut join_filters = None;
+ let mut other_subqry_exprs = vec![];
+ if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() {
+ // split into filters
+ let mut subqry_filter_exprs = vec![];
+ split_conjunction(&subqry_filter.predicate, &mut subqry_filter_exprs);
+ verify_not_disjunction(&subqry_filter_exprs)?;
+
+ // Grab column names to join on
+ let (col_exprs, other_exprs) =
+ find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())
+ .map_err(|e| context!("column correlation not found", e))?;
+ if !col_exprs.is_empty() {
+ // it's correlated
+ subqry_input = subqry_filter.input.clone();
+ (outer_cols, subqry_cols, join_filters) =
+ exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)
+ .map_err(|e| context!("column correlation not found", e))?;
+ other_subqry_exprs = other_exprs;
+ }
+ }
+
+ let (subqry_cols, outer_cols) =
+ merge_cols((&[subquery_col], &subqry_cols), (&[outer_col], &outer_cols));
+
+ // build subquery side of join - the thing the subquery was querying
+ let subqry_alias = format!("__sq_{}", optimizer_config.next_id());
+ let mut subqry_plan = LogicalPlanBuilder::from((*subqry_input).clone());
+ if let Some(expr) = combine_filters(&other_subqry_exprs) {
+ // if the subquery had additional expressions, restore them
+ subqry_plan = subqry_plan.filter(expr)?
+ }
+ let projection = alias_cols(&subqry_cols);
+ let subqry_plan = subqry_plan
+ .project_with_alias(projection, Some(subqry_alias.clone()))?
+ .build()?;
+ debug!("subquery plan:\n{}", subqry_plan.display_indent());
+
+ // qualify the join columns for outside the subquery
+ let subqry_cols = swap_table(&subqry_alias, &subqry_cols);
+ let join_keys = (outer_cols, subqry_cols);
+
+ // join our sub query into the main plan
+ let join_type = match query_info.negated {
+ true => JoinType::Anti,
+ false => JoinType::Semi,
+ };
+ let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join(
+ &subqry_plan,
+ join_type,
+ join_keys,
+ join_filters,
+ )?;
+ if let Some(expr) = combine_filters(outer_other_exprs) {
+ new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them
+ }
+ let new_plan = new_plan.build()?;
+
+ debug!("where in optimized:\n{}", new_plan.display_indent());
+ Ok(new_plan)
+}
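
The rewrite above replaces an `IN` predicate with a semi join (or an anti join for `NOT IN`). As a minimal sketch of those join semantics over plain in-memory keys (not DataFusion plans, and ignoring SQL `NULL` semantics for `NOT IN`):

```rust
use std::collections::HashSet;

// Semi join: keep left rows whose key appears on the right side (IN).
fn semi_join(left: &[i64], right: &[i64]) -> Vec<i64> {
    let keys: HashSet<i64> = right.iter().copied().collect();
    left.iter().copied().filter(|k| keys.contains(k)).collect()
}

// Anti join: keep left rows whose key does NOT appear on the right (NOT IN).
fn anti_join(left: &[i64], right: &[i64]) -> Vec<i64> {
    let keys: HashSet<i64> = right.iter().copied().collect();
    left.iter().copied().filter(|k| !keys.contains(k)).collect()
}

fn main() {
    let custkeys = [1, 2, 3];
    let order_custkeys = [2, 3, 3];
    assert_eq!(semi_join(&custkeys, &order_custkeys), vec![2, 3]);
    assert_eq!(anti_join(&custkeys, &order_custkeys), vec![1]);
}
```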
+
+struct SubqueryInfo {
+ query: Subquery,
+ where_in_expr: Expr,
+ negated: bool,
+}
+
+impl SubqueryInfo {
+ pub fn new(query: Subquery, expr: Expr, negated: bool) -> Self {
+ Self {
+ query,
+ where_in_expr: expr,
+ negated,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::*;
+ use datafusion_common::Result;
+ use datafusion_expr::{
+ col, in_subquery, lit, logical_plan::LogicalPlanBuilder, not_in_subquery,
+ };
+ use std::ops::Add;
+
+ #[cfg(test)]
+ #[ctor::ctor]
+ fn init() {
+ let _ = env_logger::try_init();
+ }
+
+ /// Test multiple correlated subqueries
+ /// See subqueries.rs where_in_multiple()
+ #[test]
+ fn multiple_subqueries() -> Result<()> {
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("customer.c_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ in_subquery(col("customer.c_custkey"), orders.clone())
+ .and(in_subquery(col("customer.c_custkey"), orders)),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+ debug!("plan to optimize:\n{}", plan.display_indent());
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
+ Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_2 [o_custkey:Int64]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test recursive correlated subqueries
+ /// See subqueries.rs where_in_recursive()
+ #[test]
+ fn recursive_subqueries() -> Result<()> {
+ let lineitem = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
+ .filter(col("lineitem.l_orderkey").eq(col("orders.o_orderkey")))?
+ .project(vec![col("lineitem.l_orderkey")])?
+ .build()?,
+ );
+
+ let orders = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ in_subquery(col("orders.o_orderkey"), lineitem)
+ .and(col("orders.o_custkey").eq(col("customer.c_custkey"))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), orders))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #__sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_2 [o_custkey:Int64]
+ Semi Join: #orders.o_orderkey = #__sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ Projection: #lineitem.l_orderkey AS l_orderkey, alias=__sq_1 [l_orderkey:Int64]
+ TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery filter with additional subquery filters
+ #[test]
+ fn in_subquery_with_subquery_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .and(col("o_orderkey").eq(lit(1))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64]
+ Filter: #orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery with no columns in schema
+ #[test]
+ fn in_subquery_no_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // Query will fail, but we can still transform the plan
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64]
+ Filter: #customer.c_custkey = #customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for IN subquery with both columns in schema
+ #[test]
+ fn in_subquery_with_no_correlated_cols() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64]
+ Filter: #orders.o_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery not equal
+ #[test]
+ fn in_subquery_where_not_eq() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").not_eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Semi Join: #customer.c_custkey = #__sq_1.o_custkey Filter: #customer.c_custkey != #orders.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery less than
+ #[test]
+ fn in_subquery_where_less_than() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").lt(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // can't optimize on arbitrary expressions (yet)
+ assert_optimizer_err(
+ &DecorrelateWhereIn::new(),
+ &plan,
+ "column correlation not found",
+ );
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery filter with subquery disjunction
+ #[test]
+ fn in_subquery_with_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(
+ col("customer.c_custkey")
+ .eq(col("orders.o_custkey"))
+ .or(col("o_orderkey").eq(lit(1))),
+ )?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ assert_optimizer_err(
+ &DecorrelateWhereIn::new(),
+ &plan,
+ "Optimizing disjunctions not supported!",
+ );
+ Ok(())
+ }
+
+ /// Test for correlated IN without projection
+ #[test]
+ fn in_subquery_no_projection() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // Maybe okay if the table only has a single column?
+ assert_optimizer_err(
+ &DecorrelateWhereIn::new(),
+ &plan,
+ "a projection is required",
+ );
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery join on expression
+ #[test]
+ fn in_subquery_join_expr() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey").add(lit(1)), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // TODO: support join on expression
+ assert_optimizer_err(
+ &DecorrelateWhereIn::new(),
+ &plan,
+ "column comparison required",
+ );
+ Ok(())
+ }
+
+ /// Test for correlated IN expressions
+ #[test]
+ fn in_subquery_project_expr() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey").add(lit(1))])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(in_subquery(col("customer.c_custkey"), sq))?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // TODO: support join on expressions?
+ assert_optimizer_err(
+ &DecorrelateWhereIn::new(),
+ &plan,
+ "single column projection required",
+ );
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery multiple projected columns
+ #[test]
+ fn in_subquery_multi_col() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ in_subquery(col("customer.c_custkey"), sq)
+ .and(col("c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ assert_optimizer_err(
+ &DecorrelateWhereIn::new(),
+ &plan,
+ "single expression projection required",
+ );
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery filter with additional filters
+ #[test]
+ fn should_support_additional_filters() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ in_subquery(col("customer.c_custkey"), sq)
+ .and(col("c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
+ Semi Join: #customer.c_custkey = #__sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]
+ Projection: #orders.o_custkey AS o_custkey, alias=__sq_1 [o_custkey:Int64]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery filter with disjunctions
+ #[test]
+ fn in_subquery_disjunction() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(scan_tpch_table("orders"))
+ .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
+ .project(vec![col("orders.o_custkey")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
+ .filter(
+ in_subquery(col("customer.c_custkey"), sq)
+ .or(col("customer.c_custkey").eq(lit(1))),
+ )?
+ .project(vec![col("customer.c_custkey")])?
+ .build()?;
+
+ // TODO: support disjunction - for now expect unaltered plan
+ let expected = r#"Projection: #customer.c_custkey [c_custkey:Int64]
+ Filter: #customer.c_custkey IN (<subquery>) OR #customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]
+ Subquery: [o_custkey:Int64]
+ Projection: #orders.o_custkey [o_custkey:Int64]
+ Filter: #customer.c_custkey = #orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
+ TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for correlated IN subquery filter
+ #[test]
+ fn in_subquery_correlated() -> Result<()> {
+ let sq = Arc::new(
+ LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
+ .filter(col("test.a").eq(col("sq.a")))?
+ .project(vec![col("c")])?
+ .build()?,
+ );
+
+ let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
+ .filter(in_subquery(col("c"), sq))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = r#"Projection: #test.b [b:UInt32]
+ Semi Join: #test.c = #__sq_1.c, #test.a = #__sq_1.a [a:UInt32, b:UInt32, c:UInt32]
+ TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+ Projection: #sq.c AS c, #sq.a AS a, alias=__sq_1 [c:UInt32, a:UInt32]
+ TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for single IN subquery filter
+ #[test]
+ fn in_subquery_simple() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(in_subquery(col("c"), test_subquery_with_name("sq")?))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = r#"Projection: #test.b [b:UInt32]
+ Semi Join: #test.c = #__sq_1.c [a:UInt32, b:UInt32, c:UInt32]
+ TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+ Projection: #sq.c AS c, alias=__sq_1 [c:UInt32]
+ TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+
+ /// Test for single NOT IN subquery filter
+ #[test]
+ fn not_in_subquery_simple() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(not_in_subquery(col("c"), test_subquery_with_name("sq")?))?
+ .project(vec![col("test.b")])?
+ .build()?;
+
+ let expected = r#"Projection: #test.b [b:UInt32]
+ Anti Join: #test.c = #__sq_1.c [a:UInt32, b:UInt32, c:UInt32]
+ TableScan: test [a:UInt32, b:UInt32, c:UInt32]
+ Projection: #sq.c AS c, alias=__sq_1 [c:UInt32]
+ TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;
+
+ assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected);
+ Ok(())
+ }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index a6b7cfcbb..588903ad0 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -16,6 +16,9 @@
// under the License.
pub mod common_subexpr_eliminate;
+pub mod decorrelate_scalar_subquery;
+pub mod decorrelate_where_exists;
+pub mod decorrelate_where_in;
pub mod eliminate_filter;
pub mod eliminate_limit;
pub mod expr_simplifier;
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index 86e12bc30..fc7d0bd8a 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -15,9 +15,11 @@
// specific language governing permissions and limitations
// under the License.
+use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
-use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder};
+use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder};
+use std::sync::Arc;
pub mod user_defined;
@@ -54,3 +56,76 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
.collect();
assert_eq!(actual, expected);
}
+
+pub fn test_subquery_with_name(name: &str) -> Result<Arc<LogicalPlan>> {
+ let table_scan = test_table_scan_with_name(name)?;
+ Ok(Arc::new(
+ LogicalPlanBuilder::from(table_scan)
+ .project(vec![col("c")])?
+ .build()?,
+ ))
+}
+
+pub fn scan_tpch_table(table: &str) -> LogicalPlan {
+ let schema = Arc::new(get_tpch_table_schema(table));
+ table_scan(Some(table), &schema, None)
+ .unwrap()
+ .build()
+ .unwrap()
+}
+
+pub fn get_tpch_table_schema(table: &str) -> Schema {
+ match table {
+ "customer" => Schema::new(vec![
+ Field::new("c_custkey", DataType::Int64, false),
+ Field::new("c_name", DataType::Utf8, false),
+ ]),
+
+ "orders" => Schema::new(vec![
+ Field::new("o_orderkey", DataType::Int64, false),
+ Field::new("o_custkey", DataType::Int64, false),
+ Field::new("o_orderstatus", DataType::Utf8, false),
+ Field::new("o_totalprice", DataType::Float64, true),
+ ]),
+
+ "lineitem" => Schema::new(vec![
+ Field::new("l_orderkey", DataType::Int64, false),
+ Field::new("l_partkey", DataType::Int64, false),
+ Field::new("l_suppkey", DataType::Int64, false),
+ Field::new("l_linenumber", DataType::Int32, false),
+ Field::new("l_quantity", DataType::Float64, false),
+ Field::new("l_extendedprice", DataType::Float64, false),
+ ]),
+
+ _ => unimplemented!("Table: {}", table),
+ }
+}
+
+pub fn assert_optimized_plan_eq(
+ rule: &dyn OptimizerRule,
+ plan: &LogicalPlan,
+ expected: &str,
+) {
+ let optimized_plan = rule
+ .optimize(plan, &mut OptimizerConfig::new())
+ .expect("failed to optimize plan");
+ let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
+ assert_eq!(formatted_plan, expected);
+}
+
+pub fn assert_optimizer_err(
+ rule: &dyn OptimizerRule,
+ plan: &LogicalPlan,
+ expected: &str,
+) {
+ let res = rule.optimize(plan, &mut OptimizerConfig::new());
+ match res {
+ Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"),
+ Err(ref e) => {
+ let actual = format!("{}", e);
+ if expected.is_empty() || !actual.contains(expected) {
+ assert_eq!(actual, expected)
+ }
+ }
+ }
+}
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index cd70c5091..41c75d689 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -19,12 +19,15 @@
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
+use datafusion_common::{plan_err, Column, DFSchemaRef};
+use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::{
- and,
+ and, col, combine_filters,
logical_plan::{Filter, LogicalPlan},
utils::from_plan,
Expr, Operator,
};
+use std::collections::HashSet;
use std::sync::Arc;
/// Convenience rule for writing optimizers: recursively invoke
@@ -65,6 +68,40 @@ pub fn split_conjunction<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>
}
}
+/// Recursively scans a slice of expressions for any `Or` operators
+///
+/// # Arguments
+///
+/// * `predicates` - the expressions to scan
+///
+/// # Return value
+///
+/// An `Err` (via `plan_err!`) if a disjunction is found, `Ok(())` otherwise
+pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> {
+ struct DisjunctionVisitor {}
+
+ impl ExpressionVisitor for DisjunctionVisitor {
+ fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
+ match expr {
+ Expr::BinaryExpr {
+ left: _,
+ op: Operator::Or,
+ right: _,
+ } => {
+ plan_err!("Optimizing disjunctions not supported!")
+ }
+ _ => Ok(Recursion::Continue(self)),
+ }
+ }
+ }
+
+ for predicate in predicates.iter() {
+ predicate.accept(DisjunctionVisitor {})?;
+ }
+
+ Ok(())
+}
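
The same recursive check can be sketched over a toy expression type (a stand-in for DataFusion's `Expr`, kept dependency-free):

```rust
// Toy expression: either a leaf predicate or a binary AND/OR.
enum Expr {
    Leaf(&'static str),
    And(Box<Expr>, Box<Expr>),
    Or(Box<Expr>, Box<Expr>),
}

// Recursively reject any expression containing an OR, mirroring
// verify_not_disjunction above.
fn verify_not_disjunction(e: &Expr) -> Result<(), String> {
    match e {
        Expr::Leaf(_) => Ok(()),
        Expr::And(l, r) => {
            verify_not_disjunction(l)?;
            verify_not_disjunction(r)
        }
        Expr::Or(_, _) => Err("Optimizing disjunctions not supported!".to_string()),
    }
}

fn main() {
    let conj = Expr::And(Box::new(Expr::Leaf("a = b")), Box::new(Expr::Leaf("c = 1")));
    assert!(verify_not_disjunction(&conj).is_ok());
    let disj = Expr::Or(Box::new(Expr::Leaf("a = b")), Box::new(Expr::Leaf("c = 1")));
    assert!(verify_not_disjunction(&disj).is_err());
}
```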
+
/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
/// its predicate be all `predicates` ANDed.
pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
@@ -82,6 +119,202 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
})
}
+/// Looks for correlating expressions: equality expressions with one field from the subquery, and
+/// one not in the subquery (closed over from the outer scope)
+///
+/// # Arguments
+///
+/// * `exprs` - List of expressions that may or may not be joins
+/// * `schema` - the subquery schema; its fully qualified (table.col) field names determine which columns are local
+///
+/// # Return value
+///
+/// Tuple of (expressions containing joins, remaining non-join expressions)
+pub fn find_join_exprs(
+ exprs: Vec<&Expr>,
+ schema: &DFSchemaRef,
+) -> Result<(Vec<Expr>, Vec<Expr>)> {
+ let fields: HashSet<_> = schema
+ .fields()
+ .iter()
+ .map(|it| it.qualified_name())
+ .collect();
+
+ let mut joins = vec![];
+ let mut others = vec![];
+ for filter in exprs.iter() {
+ let (left, op, right) = match filter {
+ Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()),
+ _ => {
+ others.push((*filter).clone());
+ continue;
+ }
+ };
+ let left = match left {
+ Expr::Column(c) => c,
+ _ => {
+ others.push((*filter).clone());
+ continue;
+ }
+ };
+ let right = match right {
+ Expr::Column(c) => c,
+ _ => {
+ others.push((*filter).clone());
+ continue;
+ }
+ };
+ if fields.contains(&left.flat_name()) && fields.contains(&right.flat_name()) {
+ others.push((*filter).clone());
+ continue; // both columns present (none closed-upon)
+ }
+ if !fields.contains(&left.flat_name()) && !fields.contains(&right.flat_name()) {
+ others.push((*filter).clone());
+ continue; // neither column present (syntax error?)
+ }
+ match op {
+ Operator::Eq => {}
+ Operator::NotEq => {}
+ _ => {
+ plan_err!(format!("can't optimize {} column comparison", op))?;
+ }
+ }
+
+ joins.push((*filter).clone())
+ }
+
+ Ok((joins, others))
+}
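
The partitioning idea can be sketched with plain strings in place of `Expr`: an equality whose columns straddle the subquery schema is a correlation (join) predicate, and anything else stays behind as an ordinary filter:

```rust
use std::collections::HashSet;

// (left column, right column) of an equality predicate.
type Eq = (&'static str, &'static str);

// Split predicates into (correlating, other): exactly one side must be
// outside the subquery schema for the predicate to correlate.
fn find_join_exprs(preds: &[Eq], schema: &HashSet<&str>) -> (Vec<Eq>, Vec<Eq>) {
    let mut joins = vec![];
    let mut others = vec![];
    for &(l, r) in preds {
        if schema.contains(l) != schema.contains(r) {
            joins.push((l, r)); // one side closed over from the outer scope
        } else {
            others.push((l, r)); // both local (plain filter) or both outer
        }
    }
    (joins, others)
}

fn main() {
    let schema: HashSet<&str> = ["orders.o_custkey", "orders.o_orderkey"].into();
    let preds = [
        ("customer.c_custkey", "orders.o_custkey"), // correlated
        ("orders.o_orderkey", "orders.o_orderkey"), // local filter
    ];
    let (joins, others) = find_join_exprs(&preds, &schema);
    assert_eq!(joins, vec![("customer.c_custkey", "orders.o_custkey")]);
    assert_eq!(others, vec![("orders.o_orderkey", "orders.o_orderkey")]);
}
```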
+
+/// Extracts correlating columns from expressions
+///
+/// # Arguments
+///
+/// * `exprs` - List of expressions that correlate a subquery to an outer scope
+/// * `schema` - the subquery schema; its fully qualified (table.col) field names determine which side of each equality is the subquery column
+/// * `include_negated` - true if `NotEq` counts as a join operator
+///
+/// # Return value
+///
+/// Tuple of (outer-scope cols, subquery cols, remaining non-correlation expressions combined into one optional filter)
+pub fn exprs_to_join_cols(
+ exprs: &[Expr],
+ schema: &DFSchemaRef,
+ include_negated: bool,
+) -> Result<(Vec<Column>, Vec<Column>, Option<Expr>)> {
+ let fields: HashSet<_> = schema
+ .fields()
+ .iter()
+ .map(|it| it.qualified_name())
+ .collect();
+
+ let mut joins: Vec<(String, String)> = vec![];
+ let mut others: Vec<Expr> = vec![];
+ for filter in exprs.iter() {
+ let (left, op, right) = match filter {
+ Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()),
+ _ => plan_err!("Invalid correlation expression!")?,
+ };
+ match op {
+ Operator::Eq => {}
+ Operator::NotEq => {
+ if !include_negated {
+ others.push((*filter).clone());
+ continue;
+ }
+ }
+ _ => plan_err!(format!("Correlation operator unsupported: {}", op))?,
+ }
+ let left = left.try_into_col()?;
+ let right = right.try_into_col()?;
+ let sorted = if fields.contains(&left.flat_name()) {
+ (right.flat_name(), left.flat_name())
+ } else {
+ (left.flat_name(), right.flat_name())
+ };
+ joins.push(sorted);
+ }
+
+ let (left_cols, right_cols): (Vec<_>, Vec<_>) = joins
+ .into_iter()
+ .map(|(l, r)| (Column::from(l.as_str()), Column::from(r.as_str())))
+ .unzip();
+ let pred = combine_filters(&others);
+
+ Ok((left_cols, right_cols, pred))
+}
+
+/// Returns the first (and only) element in a slice, or an error
+///
+/// # Arguments
+///
+/// * `slice` - The slice to extract from
+///
+/// # Return value
+///
+/// The first element, or an error
+pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
+ match slice {
+ [it] => Ok(it),
+ [] => plan_err!("No items found!"),
+ _ => plan_err!("More than one item found!"),
+ }
+}
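
The slice-pattern match above generalizes beyond plan errors; a dependency-free version of the same helper:

```rust
// Return the sole element of a slice, or an error describing why not.
fn only_or_err<T>(slice: &[T]) -> Result<&T, &'static str> {
    match slice {
        [it] => Ok(it),
        [] => Err("No items found!"),
        _ => Err("More than one item found!"),
    }
}

fn main() {
    assert_eq!(only_or_err(&[42]), Ok(&42));
    assert!(only_or_err::<i32>(&[]).is_err());
    assert!(only_or_err(&[1, 2]).is_err());
}
```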
+
+/// Merge and deduplicate two sets of Column slices
+///
+/// # Arguments
+///
+/// * `a` - A tuple of slices of Columns
+/// * `b` - A tuple of slices of Columns
+///
+/// # Return value
+///
+/// The pairwise combination of the two sides with adjacent duplicate pairs removed
+pub fn merge_cols(
+ a: (&[Column], &[Column]),
+ b: (&[Column], &[Column]),
+) -> (Vec<Column>, Vec<Column>) {
+ let e =
+ a.0.iter()
+ .map(|it| it.flat_name())
+ .chain(a.1.iter().map(|it| it.flat_name()))
+ .map(|it| Column::from(it.as_str()));
+ let f =
+ b.0.iter()
+ .map(|it| it.flat_name())
+ .chain(b.1.iter().map(|it| it.flat_name()))
+ .map(|it| Column::from(it.as_str()));
+ let mut g = e.zip(f).collect::<Vec<_>>();
+ g.dedup();
+ g.into_iter().unzip()
+}
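
Note that `merge_cols` pairs the two column lists positionally and then calls `Vec::dedup`, which only removes *consecutive* duplicate pairs. A sketch of that behavior with plain strings:

```rust
// Pair two equal-length column lists positionally, then drop adjacent
// duplicate pairs (Vec::dedup removes consecutive repeats only).
fn merge_cols(a: Vec<&str>, b: Vec<&str>) -> Vec<(String, String)> {
    let mut pairs: Vec<(String, String)> = a
        .into_iter()
        .zip(b)
        .map(|(l, r)| (l.to_string(), r.to_string()))
        .collect();
    pairs.dedup();
    pairs
}

fn main() {
    let merged = merge_cols(vec!["c", "c", "a"], vec!["sq.c", "sq.c", "sq.a"]);
    assert_eq!(merged.len(), 2);
    assert_eq!(merged[0], ("c".to_string(), "sq.c".to_string()));
    assert_eq!(merged[1], ("a".to_string(), "sq.a".to_string()));
}
```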
+
+/// Change the relation on a slice of Columns
+///
+/// # Arguments
+///
+/// * `new_table` - The table/relation for the new columns
+/// * `cols` - A slice of Columns
+///
+/// # Return value
+///
+/// A new slice of columns, now belonging to the new table
+pub fn swap_table(new_table: &str, cols: &[Column]) -> Vec<Column> {
+ cols.iter()
+ .map(|it| Column {
+ relation: Some(new_table.to_string()),
+ name: it.name.clone(),
+ })
+ .collect()
+}
+
+pub fn alias_cols(cols: &[Column]) -> Vec<Expr> {
+ cols.iter()
+ .map(|it| col(it.flat_name().as_str()).alias(it.name.as_str()))
+ .collect()
+}
+
#[cfg(test)]
mod tests {
use super::*;