You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/08/31 14:58:38 UTC

[arrow-datafusion] branch master updated: Add `Aggregate::try_new` with validation checks (#3286)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d37016de Add `Aggregate::try_new` with validation checks (#3286)
3d37016de is described below

commit 3d37016de3647d90e6f78fd0e106142923799969
Author: Andy Grove <an...@gmail.com>
AuthorDate: Wed Aug 31 08:58:32 2022 -0600

    Add `Aggregate::try_new` with validation checks (#3286)
    
    * Add Aggregate::try_new with validation checks
    
    * fix calculation of number of grouping expressions
    
    * use suggested error message
---
 datafusion/core/src/physical_plan/planner.rs       | 10 ++++----
 datafusion/expr/src/logical_plan/builder.rs        |  8 +++---
 datafusion/expr/src/logical_plan/plan.rs           | 30 +++++++++++++++++++++-
 datafusion/expr/src/utils.rs                       | 28 +++++++++++++++-----
 .../optimizer/src/common_subexpr_eliminate.rs      | 12 ++++-----
 datafusion/optimizer/src/projection_push_down.rs   | 12 ++++-----
 .../optimizer/src/single_distinct_to_groupby.rs    | 25 +++++++++---------
 7 files changed, 84 insertions(+), 41 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 747cd1a20..a67e2dac7 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -645,12 +645,12 @@ impl DefaultPhysicalPlanner {
                 LogicalPlan::Distinct(Distinct {input}) => {
                     // Convert distinct to groupby with no aggregations
                     let group_expr = expand_wildcard(input.schema(), input)?;
-                    let aggregate =  LogicalPlan::Aggregate(Aggregate {
-                            input: input.clone(),
+                    let aggregate =  LogicalPlan::Aggregate(Aggregate::try_new(
+                            input.clone(),
                             group_expr,
-                            aggr_expr: vec![],
-                            schema: input.schema().clone()
-                        }
+                            vec![],
+                            input.schema().clone()
+                    )?
                     );
                     Ok(self.create_initial_plan(&aggregate, session_state).await?)
                 }
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 2946a74af..41ba95140 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -701,12 +701,12 @@ impl LogicalPlanBuilder {
             exprlist_to_fields(all_expr, &self.plan)?,
             self.plan.schema().metadata().clone(),
         )?;
-        Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
-            input: Arc::new(self.plan.clone()),
+        Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new(
+            Arc::new(self.plan.clone()),
             group_expr,
             aggr_expr,
-            schema: DFSchemaRef::new(aggr_schema),
-        })))
+            DFSchemaRef::new(aggr_schema),
+        )?)))
     }
 
     /// Create an expression to represent the explanation of the plan
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 2d5eb4680..cec55bfc1 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -17,7 +17,7 @@
 
 use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
 use crate::logical_plan::extension::UserDefinedLogicalNode;
-use crate::utils::exprlist_to_fields;
+use crate::utils::{exprlist_to_fields, grouping_set_expr_count};
 use crate::{Expr, TableProviderFilterPushDown, TableSource};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError};
@@ -1314,6 +1314,34 @@ pub struct Aggregate {
 }
 
 impl Aggregate {
+    pub fn try_new(
+        input: Arc<LogicalPlan>,
+        group_expr: Vec<Expr>,
+        aggr_expr: Vec<Expr>,
+        schema: DFSchemaRef,
+    ) -> datafusion_common::Result<Self> {
+        if group_expr.is_empty() && aggr_expr.is_empty() {
+            return Err(DataFusionError::Plan(
+                "Aggregate requires at least one grouping or aggregate expression"
+                    .to_string(),
+            ));
+        }
+        let group_expr_count = grouping_set_expr_count(&group_expr)?;
+        if schema.fields().len() != group_expr_count + aggr_expr.len() {
+            return Err(DataFusionError::Plan(format!(
+                "Aggregate schema has wrong number of fields. Expected {} got {}",
+                group_expr_count + aggr_expr.len(),
+                schema.fields().len()
+            )));
+        }
+        Ok(Self {
+            input,
+            group_expr,
+            aggr_expr,
+            schema,
+        })
+    }
+
     pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Aggregate> {
         match plan {
             LogicalPlan::Aggregate(it) => Ok(it),
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 7d3f78b8f..e748536d7 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
     Ok(())
 }
 
+/// Count the number of distinct exprs in a list of group by expressions. If the
+/// first element is a `GroupingSet` expression then it must be the only expr.
+pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
+    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
+        if group_expr.len() > 1 {
+            return Err(DataFusionError::Plan(
+                "Invalid group by expressions, GroupingSet must be the only expression"
+                    .to_string(),
+            ));
+        }
+        Ok(grouping_set.distinct_expr().len())
+    } else {
+        Ok(group_expr.len())
+    }
+}
+
 /// Find all distinct exprs in a list of group by expressions. If the
 /// first element is a `GroupingSet` expression then it must be the only expr.
 pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
@@ -395,12 +411,12 @@ pub fn from_plan(
         })),
         LogicalPlan::Aggregate(Aggregate {
             group_expr, schema, ..
-        }) => Ok(LogicalPlan::Aggregate(Aggregate {
-            group_expr: expr[0..group_expr.len()].to_vec(),
-            aggr_expr: expr[group_expr.len()..].to_vec(),
-            input: Arc::new(inputs[0].clone()),
-            schema: schema.clone(),
-        })),
+        }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new(
+            Arc::new(inputs[0].clone()),
+            expr[0..group_expr.len()].to_vec(),
+            expr[group_expr.len()..].to_vec(),
+            schema.clone(),
+        )?)),
         LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort {
             expr: expr.to_vec(),
             input: Arc::new(inputs[0].clone()),
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 305283d99..f015aeaa0 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -189,12 +189,12 @@ fn optimize(
             let new_aggr_expr = new_expr.pop().unwrap();
             let new_group_expr = new_expr.pop().unwrap();
 
-            Ok(LogicalPlan::Aggregate(Aggregate {
-                input: Arc::new(new_input),
-                group_expr: new_group_expr,
-                aggr_expr: new_aggr_expr,
-                schema: schema.clone(),
-            }))
+            Ok(LogicalPlan::Aggregate(Aggregate::try_new(
+                Arc::new(new_input),
+                new_group_expr,
+                new_aggr_expr,
+                schema.clone(),
+            )?))
         }
         LogicalPlan::Sort(Sort { expr, input }) => {
             let arrays = to_arrays(expr, input, &mut expr_set)?;
diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs
index aa3cdfb42..80cc1044d 100644
--- a/datafusion/optimizer/src/projection_push_down.rs
+++ b/datafusion/optimizer/src/projection_push_down.rs
@@ -345,18 +345,18 @@ fn optimize_plan(
                 schema.metadata().clone(),
             )?;
 
-            Ok(LogicalPlan::Aggregate(Aggregate {
-                group_expr: group_expr.clone(),
-                aggr_expr: new_aggr_expr,
-                input: Arc::new(optimize_plan(
+            Ok(LogicalPlan::Aggregate(Aggregate::try_new(
+                Arc::new(optimize_plan(
                     _optimizer,
                     input,
                     &new_required_columns,
                     true,
                     _optimizer_config,
                 )?),
-                schema: DFSchemaRef::new(new_schema),
-            }))
+                group_expr.clone(),
+                new_aggr_expr,
+                DFSchemaRef::new(new_schema),
+            )?))
         }
         // scans:
         // * remove un-used columns from the scan projection
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 3244fac8d..e36706caa 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -100,12 +100,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                     all_field,
                     input.schema().metadata().clone(),
                 )?;
-                let grouped_agg = LogicalPlan::Aggregate(Aggregate {
-                    input: input.clone(),
-                    group_expr: all_group_args,
-                    aggr_expr: Vec::new(),
-                    schema: Arc::new(grouped_schema.clone()),
-                });
+                let grouped_agg = LogicalPlan::Aggregate(Aggregate::try_new(
+                    input.clone(),
+                    all_group_args,
+                    Vec::new(),
+                    Arc::new(grouped_schema.clone()),
+                )?);
                 let grouped_agg = optimize_children(&grouped_agg);
                 let final_agg_schema = Arc::new(DFSchema::new_with_metadata(
                     base_group_expr
@@ -129,13 +129,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                         ));
                     });
 
-                let final_agg = LogicalPlan::Aggregate(Aggregate {
-                    input: Arc::new(grouped_agg?),
-                    group_expr: group_expr.clone(),
-                    aggr_expr: new_aggr_expr,
-                    schema: final_agg_schema,
-                });
-
+                let final_agg = LogicalPlan::Aggregate(Aggregate::try_new(
+                    Arc::new(grouped_agg?),
+                    group_expr.clone(),
+                    new_aggr_expr,
+                    final_agg_schema,
+                )?);
                 Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
                     alias_expr,
                     Arc::new(final_agg),