You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "alamb (via GitHub)" <gi...@apache.org> on 2023/04/10 19:52:19 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #5837: Add new physical rule CombinePartialFinalAggregate

alamb commented on code in PR #5837:
URL: https://github.com/apache/arrow-datafusion/pull/5837#discussion_r1162025055


##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -147,6 +149,24 @@ impl PhysicalGroupBy {
     }
 }
 
+impl PartialEq for PhysicalGroupBy {
+    fn eq(&self, other: &PhysicalGroupBy) -> bool {
+        self.expr.len() == other.expr.len()
+            && self
+                .expr
+                .iter()
+                .zip(other.expr.iter())
+                .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)

Review Comment:
   I wondered why `PartialEq` needed to be implemented manually rather than derived, so I tried replacing the manual impl with `#[derive(PartialEq)]` and got this error:
   
   ```
   error[E0369]: binary operation `==` cannot be applied to type `Vec<(Arc<dyn PhysicalExpr>, std::string::String)>`
     --> datafusion/core/src/physical_plan/aggregates/mod.rs:91:5
      |
   88 | #[derive(Clone, Debug, Default, PartialEq)]
      |                                 --------- in this derive macro expansion
   ...
   91 |     expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
      |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      |
      = note: this error originates in the derive macro `PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info)
   
   ```



##########
datafusion/core/tests/sql/joins.rs:
##########
@@ -1772,6 +1772,171 @@ async fn right_semi_join() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn join_and_aggregate_on_same_key() -> Result<()> {
+    let ctx = create_join_context("t1_id", "t2_id", true)?;
+    let sql = "select distinct(t1.t1_id) from t1 inner join t2 on t1.t1_id = t2.t2_id";
+
+    // assert logical plan
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
+    let plan = dataframe.into_optimized_plan().unwrap();
+
+    let expected = vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Aggregate: groupBy=[[t1.t1_id]], aggr=[[]] [t1_id:UInt32;N]",
+        "    Projection: t1.t1_id [t1_id:UInt32;N]",
+        "      Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t2_id:UInt32;N]",
+        "        TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]",
+        "        TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+    ];
+
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let msg = format!("Creating physical plan for '{sql}'");
+    let dataframe = ctx.sql(sql).await.expect(&msg);
+    let physical_plan = dataframe.create_physical_plan().await?;
+    let expected =
+        vec![
+            "AggregateExec: mode=Single, gby=[t1_id@0 as t1_id], aggr=[]",

Review Comment:
   Is it correct that this plan can use a single aggregate because it is already partitioned on the group key (t1_id) after the join?



##########
datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs:
##########
@@ -0,0 +1,120 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! CombinePartialFinalAggregate optimizer rule checks the adjacent Partial and Final AggregateExecs
+//! and try to combine them if necessary
+use crate::error::Result;
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+use crate::physical_plan::ExecutionPlan;
+use datafusion_common::config::ConfigOptions;
+use std::sync::Arc;
+
+use datafusion_common::tree_node::{Transformed, TreeNode};
+
+/// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs
+/// into a Single AggregateExec if their grouping exprs and aggregate exprs equal.
+///
+/// This rule should be applied after the EnforceDistribution and EnforceSorting rules
+///
+#[derive(Default)]
+pub struct CombinePartialFinalAggregate {}
+
+impl CombinePartialFinalAggregate {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        _config: &ConfigOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        plan.transform_down(&|plan| {
+            let transformed = plan.as_any().downcast_ref::<AggregateExec>().and_then(
+                |AggregateExec {
+                     mode: final_mode,
+                     input: final_input,
+                     group_by: final_group_by,
+                     aggr_expr: final_aggr_expr,
+                     ..
+                 }| {
+                    if matches!(
+                        final_mode,
+                        AggregateMode::Final | AggregateMode::FinalPartitioned
+                    ) {
+                        final_input
+                            .as_any()
+                            .downcast_ref::<AggregateExec>()
+                            .and_then(
+                                |AggregateExec {
+                                     mode: input_mode,
+                                     input: partial_input,
+                                     group_by: input_group_by,
+                                     aggr_expr: input_aggr_expr,
+                                     input_schema,
+                                     ..
+                                 }| {
+                                    if matches!(input_mode, AggregateMode::Partial)
+                                        && final_group_by.eq(input_group_by)
+                                        && final_aggr_expr.len() == input_aggr_expr.len()
+                                        && final_aggr_expr
+                                            .iter()
+                                            .zip(input_aggr_expr.iter())
+                                            .all(|(final_expr, partial_expr)| {
+                                                final_expr.eq(partial_expr)
+                                            })
+                                    {
+                                        AggregateExec::try_new(
+                                            AggregateMode::Single,
+                                            input_group_by.clone(),
+                                            input_aggr_expr.to_vec(),
+                                            partial_input.clone(),
+                                            input_schema.clone(),
+                                        )
+                                        .ok()
+                                        .map(Arc::new)
+                                    } else {
+                                        None
+                                    }
+                                },
+                            )
+                    } else {
+                        None
+                    }
+                },
+            );
+
+            Ok(if let Some(transformed) = transformed {
+                Transformed::Yes(transformed)
+            } else {
+                Transformed::No(plan)
+            })
+        })
+    }
+
+    fn name(&self) -> &str {
+        "CombinePartialFinalAggregate"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}

Review Comment:
   I think it would help to add unit tests for this optimizer rule so we can see what it does in isolation (and test cases such as a Partial/Final pair with different aggregate exprs not being combined).



##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -65,6 +65,8 @@ pub enum AggregateMode {
     /// with Hash repartitioning on the group keys. If a group key is
     /// duplicated, duplicate groups would be produced
     FinalPartitioned,
+    /// Single aggregate is a combination of Partial and Final aggregate mode

Review Comment:
   ```suggestion
       /// Applies the entire logical aggregation operation in a single operator,
       /// as opposed to Partial / Final modes which apply the logical aggregation using
       /// two operators.  
   ```



##########
datafusion/physical-expr/src/aggregate/mod.rs:
##########
@@ -56,7 +56,7 @@ pub(crate) mod variance;
 /// * knows how to create its accumulator
 /// * knows its accumulator's state's field
 /// * knows the expressions from whose its accumulator will receive values
-pub trait AggregateExpr: Send + Sync + Debug {
+pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {

Review Comment:
   This is an API change. I understand why we need a `PartialEq` against `dyn Any`, but it might be somewhat confusing to others.
   
   Could you add documentation describing how to implement it (perhaps pointing at the `down_cast_any_ref` utility function)?



##########
datafusion/physical-expr/src/aggregate/utils.rs:
##########
@@ -31,3 +34,17 @@ pub fn get_accum_scalar_values_as_arrays(
         .map(|s| s.to_array_of_size(1))
         .collect::<Vec<_>>())
 }
+
+pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {

Review Comment:
   Can you please document what this function does (with an example), given that it is a new `pub` function?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org