You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/03 10:45:45 UTC

[arrow-datafusion] branch master updated: fix `push_down_filter` for pushing filters on grouping columns rather than aggregate columns (#4447)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 050969242 fix `push_down_filter` for pushing filters on grouping columns rather than aggregate columns (#4447)
050969242 is described below

commit 05096924218782c75bbe1c37b22765f6afb5a63e
Author: jakevin <ja...@gmail.com>
AuthorDate: Sat Dec 3 18:45:41 2022 +0800

    fix `push_down_filter` for pushing filters on grouping columns rather than aggregate columns (#4447)
    
    * fix `push_down_filter` push column instead of Expr.
    
    * remove collect to avoid performance loss
    
    * add UT
    
    * enhance filter push through agg
    
    * add comment
    
    * polish
    
    * remove wrong UT.
---
 datafusion/optimizer/src/push_down_filter.rs   | 93 ++++++++++----------------
 datafusion/optimizer/tests/integration-test.rs | 12 ++++
 2 files changed, 49 insertions(+), 56 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index e59590df5..b61ba0f21 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -17,7 +17,6 @@
 use crate::utils::conjunction;
 use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, DataFusionError, Result};
-use datafusion_expr::utils::exprlist_to_columns;
 use datafusion_expr::{
     and,
     expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
@@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter {
                 })
             }
             LogicalPlan::Aggregate(agg) => {
-                // An aggregate's aggregate columns are _not_ filter-commutable => collect these:
-                // * columns whose aggregation expression depends on
-                // * the aggregation columns themselves
-
-                // construct set of columns that `aggr_expr` depends on
-                let mut used_columns = HashSet::new();
-                exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
-                let agg_columns = agg
-                    .aggr_expr
+                // A predicate can be pushed down only if every column it references is a group-by expression.
+                let group_expr_columns = agg
+                    .group_expr
                     .iter()
-                    .map(|x| Ok(Column::from_name(x.display_name()?)))
+                    .map(|e| Ok(Column::from_qualified_name(&(e.display_name()?))))
                     .collect::<Result<HashSet<_>>>()?;
-                used_columns.extend(agg_columns);
 
                 let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
                     filter.predicate().clone(),
@@ -641,20 +633,27 @@ impl OptimizerRule for PushDownFilter {
                 let mut keep_predicates = vec![];
                 let mut push_predicates = vec![];
                 for expr in predicates {
-                    let columns = expr.to_columns()?;
-                    if columns.is_empty()
-                        || !columns
-                            .intersection(&used_columns)
-                            .collect::<HashSet<_>>()
-                            .is_empty()
-                    {
-                        keep_predicates.push(expr);
-                    } else {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                         push_predicates.push(expr);
+                    } else {
+                        keep_predicates.push(expr);
                     }
                 }
 
-                let child = match conjunction(push_predicates) {
+                // Consider the plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)].
+                // After the push-down, the column `a+b` must be replaced with Column(a)+Column(b),
+                // so we build a replace_map containing {`a+b` --> Expr(Column(a)+Column(b))}.
+                let mut replace_map = HashMap::new();
+                for expr in &agg.group_expr {
+                    replace_map.insert(expr.display_name()?, expr.clone());
+                }
+                let replaced_push_predicates = push_predicates
+                    .iter()
+                    .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
+                    .collect::<Result<Vec<_>>>()?;
+
+                let child = match conjunction(replaced_push_predicates) {
                     Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                         predicate,
                         Arc::new((*agg.input).clone()),
@@ -881,40 +880,30 @@ mod tests {
     }
 
     #[test]
-    fn filter_keep_agg() -> Result<()> {
-        let table_scan = test_table_scan()?;
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(col("b").gt(lit(10i64)))?
+    fn push_agg_need_replace_expr() -> Result<()> {
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
+            .filter(col("test.b + test.a").gt(lit(10i64)))?
             .build()?;
-        // filter of aggregate is after aggregation since they are non-commutative
-        let expected = "\
-            Filter: b > Int64(10)\
-            \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-            \n    TableScan: test";
+        let expected =
+            "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
+        \n  Filter: test.b + test.a > Int64(10)\
+        \n    TableScan: test";
         assert_optimized_plan_eq(&plan, expected)
     }
 
     #[test]
-    fn filter_keep_partial_agg() -> Result<()> {
+    fn filter_keep_agg() -> Result<()> {
         let table_scan = test_table_scan()?;
-        let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
-        let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
-        let filter = f1.or(f2);
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
-            .filter(filter)?
+            .filter(col("b").gt(lit(10i64)))?
             .build()?;
         // filter of aggregate is after aggregation since they are non-commutative
-        // (c =1 AND b > 2) OR (c = 1 AND b > 3)
-        // rewrite to CNF
-        // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)
-
         let expected = "\
-        Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
-        \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
-        \n    Filter: test.c = Int64(1) OR test.c = Int64(1)\
-        \n      TableScan: test";
+            Filter: b > Int64(10)\
+            \n  Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
+            \n    TableScan: test";
         assert_optimized_plan_eq(&plan, expected)
     }
 
@@ -1870,17 +1859,9 @@ mod tests {
     #[async_trait]
     impl TableSource for PushDownProvider {
         fn schema(&self) -> SchemaRef {
-            Arc::new(arrow::datatypes::Schema::new(vec![
-                arrow::datatypes::Field::new(
-                    "a",
-                    arrow::datatypes::DataType::Int32,
-                    true,
-                ),
-                arrow::datatypes::Field::new(
-                    "b",
-                    arrow::datatypes::DataType::Int32,
-                    true,
-                ),
+            Arc::new(Schema::new(vec![
+                Field::new("a", DataType::Int32, true),
+                Field::new("b", DataType::Int32, true),
             ]))
         }
 
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 457ea833e..701d1a84c 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() {
     assert_eq!(expected, format!("{:?}", plan));
 }
 
+#[test]
+fn push_down_filter_groupby_expr_contains_alias() {
+    let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
+    let plan = test_sql(sql).unwrap();
+    let expected = "Projection: c, COUNT(UInt8(1))\
+    \n  Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\
+    \n    Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\
+    \n      Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
+    \n        TableScan: test projection=[col_int32, col_uint32]";
+    assert_eq!(expected, format!("{:?}", plan));
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // parse the SQL
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...