You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/10 19:58:16 UTC

[arrow-datafusion] branch master updated: Fix optimizer regression with simplifying expressions in subquery filters (#3764)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 80228279d Fix optimizer regression with simplifying expressions in subquery filters (#3764)
80228279d is described below

commit 80228279d61c10903cd9707fafcbccb8b15d5e1c
Author: Andy Grove <an...@gmail.com>
AuthorDate: Mon Oct 10 13:58:10 2022 -0600

    Fix optimizer regression with simplifying expressions in subquery filters (#3764)
---
 datafusion/core/tests/sql/subqueries.rs        | 12 ++++++------
 datafusion/optimizer/src/optimizer.rs          |  4 ++++
 datafusion/optimizer/tests/integration-test.rs | 22 ++++++++++++++++++++++
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index f91018d8b..a5b246be4 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -336,10 +336,10 @@ order by s_name;
               Projection: part.p_partkey AS p_partkey, alias=__sq_1
                 Filter: part.p_name LIKE Utf8("forest%")
                   TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")]
-            Projection: lineitem.l_partkey, lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
+            Projection: lineitem.l_partkey, lineitem.l_suppkey, Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
               Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]
-                Filter: lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
-                  TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
+                Filter: lineitem.l_shipdate >= Date32("8766")
+                  TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#
         .to_string();
     assert_eq!(actual, expected);
 
@@ -393,8 +393,8 @@ order by cntrycode;"#;
                 TableScan: orders projection=[o_custkey]
               Projection: AVG(customer.c_acctbal) AS __value, alias=__sq_1
                 Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
-                  Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
-                    TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
+                  Filter: CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+                    TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[CAST(customer.c_acctbal AS Decimal128(30, 15)) > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
         .to_string();
     assert_eq!(actual, expected);
 
@@ -453,7 +453,7 @@ order by value desc;
               TableScan: supplier projection=[s_suppkey, s_nationkey]
             Filter: nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]
-        Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
+        Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, alias=__sq_1
           Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
             Inner Join: supplier.s_nationkey = nation.n_nationkey
               Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index aa10cd8a7..87e4d1ffc 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -144,6 +144,10 @@ impl Optimizer {
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(ScalarSubqueryToJoin::new()),
             Arc::new(SubqueryFilterToJoin::new()),
+            // The SimplifyExpressions rule does not simplify expressions inside
+            // subqueries, so we run it again after the rules above, which may
+            // have converted subqueries to joins
+            Arc::new(SimplifyExpressions::new()),
             Arc::new(EliminateFilter::new()),
             Arc::new(ReduceCrossJoin::new()),
             Arc::new(CommonSubexprEliminate::new()),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index f6fe685ee..12a5b4447 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -52,6 +52,28 @@ fn case_when() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn subquery_filter_with_cast() -> Result<()> {
+    // regression test for https://github.com/apache/arrow-datafusion/issues/3760
+    let sql = "SELECT col_int32 FROM test \
+    WHERE col_int32 > (\
+      SELECT AVG(col_int32) FROM test \
+      WHERE col_utf8 BETWEEN '2002-05-08' \
+        AND (cast('2002-05-08' as date) + interval '5 days')\
+    )";
+    let plan = test_sql(sql)?;
+    let expected =
+        "Projection: test.col_int32\n  Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\
+        \n    CrossJoin:\
+        \n      TableScan: test projection=[col_int32]\
+        \n      Projection: AVG(test.col_int32) AS __value, alias=__sq_1\
+        \n        Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
+        \n          Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
+        \n            TableScan: test projection=[col_int32, col_utf8]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
 #[test]
 fn case_when_aggregate() -> Result<()> {
     let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8";