You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/10 19:58:16 UTC
[arrow-datafusion] branch master updated: Fix optimizer regression with simplifying expressions in subquery filters (#3764)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 80228279d Fix optimizer regression with simplifying expressions in subquery filters (#3764)
80228279d is described below
commit 80228279d61c10903cd9707fafcbccb8b15d5e1c
Author: Andy Grove <an...@gmail.com>
AuthorDate: Mon Oct 10 13:58:10 2022 -0600
Fix optimizer regression with simplifying expressions in subquery filters (#3764)
---
datafusion/core/tests/sql/subqueries.rs | 12 ++++++------
datafusion/optimizer/src/optimizer.rs | 4 ++++
datafusion/optimizer/tests/integration-test.rs | 22 ++++++++++++++++++++++
3 files changed, 32 insertions(+), 6 deletions(-)
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index f91018d8b..a5b246be4 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -336,10 +336,10 @@ order by s_name;
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")]
- Projection: lineitem.l_partkey, lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
+ Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
- Filter: lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
- TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
+ Filter: lineitem.l_shipdate >= Date32("8766")
+ TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#
.to_string();
assert_eq!(actual, expected);
@@ -393,8 +393,8 @@ order by cntrycode;"#;
TableScan: orders projection=[o_custkey]
Projection: AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
- Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
- TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
+ Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+ TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);
@@ -453,7 +453,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
- Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
+ Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index aa10cd8a7..87e4d1ffc 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -144,6 +144,10 @@ impl Optimizer {
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(SubqueryFilterToJoin::new()),
+ // simplify expressions does not simplify expressions in subqueries, so we
+ // run it again after running the optimizations that potentially converted
+ // subqueries to joins
+ Arc::new(SimplifyExpressions::new()),
Arc::new(EliminateFilter::new()),
Arc::new(ReduceCrossJoin::new()),
Arc::new(CommonSubexprEliminate::new()),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index f6fe685ee..12a5b4447 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -52,6 +52,28 @@ fn case_when() -> Result<()> {
Ok(())
}
+#[test]
+fn subquery_filter_with_cast() -> Result<()> {
+ // regression test for https://github.com/apache/arrow-datafusion/issues/3760
+ let sql = "SELECT col_int32 FROM test \
+ WHERE col_int32 > (\
+ SELECT AVG(col_int32) FROM test \
+ WHERE col_utf8 BETWEEN '2002-05-08' \
+ AND (cast('2002-05-08' as date) + interval '5 days')\
+ )";
+ let plan = test_sql(sql)?;
+ let expected =
+ "Projection: test.col_int32\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\
+ \n CrossJoin:\
+ \n TableScan: test projection=[col_int32]\
+ \n Projection: AVG(test.col_int32) AS __value, alias=__sq_1\
+ \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
+ \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
+ \n TableScan: test projection=[col_int32, col_utf8]";
+ assert_eq!(expected, format!("{:?}", plan));
+ Ok(())
+}
+
#[test]
fn case_when_aggregate() -> Result<()> {
let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";