You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/03 10:45:45 UTC
[arrow-datafusion] branch master updated: fix `push_down_filter` for pushing filters on grouping columns rather than aggregate columns (#4447)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 050969242 fix `push_down_filter` for pushing filters on grouping columns rather than aggregate columns (#4447)
050969242 is described below
commit 05096924218782c75bbe1c37b22765f6afb5a63e
Author: jakevin <ja...@gmail.com>
AuthorDate: Sat Dec 3 18:45:41 2022 +0800
fix `push_down_filter` for pushing filters on grouping columns rather than aggregate columns (#4447)
* fix `push_down_filter` to push down columns instead of whole Exprs.
* remove collect to avoid performance loss
* add UT
* enhance filter push through agg
* add comment
* polish
* remove wrong UT.
---
datafusion/optimizer/src/push_down_filter.rs | 93 ++++++++++----------------
datafusion/optimizer/tests/integration-test.rs | 12 ++++
2 files changed, 49 insertions(+), 56 deletions(-)
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index e59590df5..b61ba0f21 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -17,7 +17,6 @@
use crate::utils::conjunction;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
-use datafusion_expr::utils::exprlist_to_columns;
use datafusion_expr::{
and,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
@@ -620,19 +619,12 @@ impl OptimizerRule for PushDownFilter {
})
}
LogicalPlan::Aggregate(agg) => {
- // An aggregate's aggregate columns are _not_ filter-commutable => collect these:
- // * columns whose aggregation expression depends on
- // * the aggregation columns themselves
-
- // construct set of columns that `aggr_expr` depends on
- let mut used_columns = HashSet::new();
- exprlist_to_columns(&agg.aggr_expr, &mut used_columns)?;
- let agg_columns = agg
- .aggr_expr
+ // We can push down predicates that reference only group-by expressions.
+ let group_expr_columns = agg
+ .group_expr
.iter()
- .map(|x| Ok(Column::from_name(x.display_name()?)))
+ .map(|e| Ok(Column::from_qualified_name(&(e.display_name()?))))
.collect::<Result<HashSet<_>>>()?;
- used_columns.extend(agg_columns);
let predicates = utils::split_conjunction_owned(utils::cnf_rewrite(
filter.predicate().clone(),
@@ -641,20 +633,27 @@ impl OptimizerRule for PushDownFilter {
let mut keep_predicates = vec![];
let mut push_predicates = vec![];
for expr in predicates {
- let columns = expr.to_columns()?;
- if columns.is_empty()
- || !columns
- .intersection(&used_columns)
- .collect::<HashSet<_>>()
- .is_empty()
- {
- keep_predicates.push(expr);
- } else {
+ let cols = expr.to_columns()?;
+ if cols.iter().all(|c| group_expr_columns.contains(c)) {
push_predicates.push(expr);
+ } else {
+ keep_predicates.push(expr);
}
}
- let child = match conjunction(push_predicates) {
+ // For a plan such as Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)],
+ // after pushing down we need to replace `a+b` with Column(a)+Column(b).
+ // So we need to create a replace_map, adding {`a+b` --> Expr(Column(a)+Column(b))}
+ let mut replace_map = HashMap::new();
+ for expr in &agg.group_expr {
+ replace_map.insert(expr.display_name()?, expr.clone());
+ }
+ let replaced_push_predicates = push_predicates
+ .iter()
+ .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
+ .collect::<Result<Vec<_>>>()?;
+
+ let child = match conjunction(replaced_push_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new((*agg.input).clone()),
@@ -881,40 +880,30 @@ mod tests {
}
#[test]
- fn filter_keep_agg() -> Result<()> {
- let table_scan = test_table_scan()?;
- let plan = LogicalPlanBuilder::from(table_scan)
- .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
- .filter(col("b").gt(lit(10i64)))?
+ fn push_agg_need_replace_expr() -> Result<()> {
+ let plan = LogicalPlanBuilder::from(test_table_scan()?)
+ .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
+ .filter(col("test.b + test.a").gt(lit(10i64)))?
.build()?;
- // filter of aggregate is after aggregation since they are non-commutative
- let expected = "\
- Filter: b > Int64(10)\
- \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
- \n TableScan: test";
+ let expected =
+ "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\
+ \n Filter: test.b + test.a > Int64(10)\
+ \n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
#[test]
- fn filter_keep_partial_agg() -> Result<()> {
+ fn filter_keep_agg() -> Result<()> {
let table_scan = test_table_scan()?;
- let f1 = col("c").eq(lit(1i64)).and(col("b").gt(lit(2i64)));
- let f2 = col("c").eq(lit(1i64)).and(col("b").gt(lit(3i64)));
- let filter = f1.or(f2);
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
- .filter(filter)?
+ .filter(col("b").gt(lit(10i64)))?
.build()?;
// filter of aggregate is after aggregation since they are non-commutative
- // (c =1 AND b > 2) OR (c = 1 AND b > 3)
- // rewrite to CNF
- // (c = 1 OR c = 1) [can pushDown] AND (c = 1 OR b > 3) AND (b > 2 OR C = 1) AND (b > 2 OR b > 3)
-
let expected = "\
- Filter: (test.c = Int64(1) OR b > Int64(3)) AND (b > Int64(2) OR test.c = Int64(1)) AND (b > Int64(2) OR b > Int64(3))\
- \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
- \n Filter: test.c = Int64(1) OR test.c = Int64(1)\
- \n TableScan: test";
+ Filter: b > Int64(10)\
+ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\
+ \n TableScan: test";
assert_optimized_plan_eq(&plan, expected)
}
@@ -1870,17 +1859,9 @@ mod tests {
#[async_trait]
impl TableSource for PushDownProvider {
fn schema(&self) -> SchemaRef {
- Arc::new(arrow::datatypes::Schema::new(vec![
- arrow::datatypes::Field::new(
- "a",
- arrow::datatypes::DataType::Int32,
- true,
- ),
- arrow::datatypes::Field::new(
- "b",
- arrow::datatypes::DataType::Int32,
- true,
- ),
+ Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Int32, true),
]))
}
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 457ea833e..701d1a84c 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -304,6 +304,18 @@ fn join_keys_in_subquery_alias_1() {
assert_eq!(expected, format!("{:?}", plan));
}
+#[test]
+fn push_down_filter_groupby_expr_contains_alias() {
+ let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
+ let plan = test_sql(sql).unwrap();
+ let expected = "Projection: c, COUNT(UInt8(1))\
+ \n Projection: test.col_int32 + test.col_uint32 AS c, COUNT(UInt8(1))\
+ \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1))]]\
+ \n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
+ \n TableScan: test projection=[col_int32, col_uint32]";
+ assert_eq!(expected, format!("{:?}", plan));
+}
+
fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...