You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/08/31 14:58:38 UTC
[arrow-datafusion] branch master updated: Add `Aggregate::try_new` with validation checks (#3286)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 3d37016de Add `Aggregate::try_new` with validation checks (#3286)
3d37016de is described below
commit 3d37016de3647d90e6f78fd0e106142923799969
Author: Andy Grove <an...@gmail.com>
AuthorDate: Wed Aug 31 08:58:32 2022 -0600
Add `Aggregate::try_new` with validation checks (#3286)
* Add Aggregate::try_new with validation checks
* fix calculation of number of grouping expressions
* use suggested error message
---
datafusion/core/src/physical_plan/planner.rs | 10 ++++----
datafusion/expr/src/logical_plan/builder.rs | 8 +++---
datafusion/expr/src/logical_plan/plan.rs | 30 +++++++++++++++++++++-
datafusion/expr/src/utils.rs | 28 +++++++++++++++-----
.../optimizer/src/common_subexpr_eliminate.rs | 12 ++++-----
datafusion/optimizer/src/projection_push_down.rs | 12 ++++-----
.../optimizer/src/single_distinct_to_groupby.rs | 25 +++++++++---------
7 files changed, 84 insertions(+), 41 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 747cd1a20..a67e2dac7 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -645,12 +645,12 @@ impl DefaultPhysicalPlanner {
LogicalPlan::Distinct(Distinct {input}) => {
// Convert distinct to groupby with no aggregations
let group_expr = expand_wildcard(input.schema(), input)?;
- let aggregate = LogicalPlan::Aggregate(Aggregate {
- input: input.clone(),
+ let aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
+ input.clone(),
group_expr,
- aggr_expr: vec![],
- schema: input.schema().clone()
- }
+ vec![],
+ input.schema().clone()
+ )?
);
Ok(self.create_initial_plan(&aggregate, session_state).await?)
}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 2946a74af..41ba95140 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -701,12 +701,12 @@ impl LogicalPlanBuilder {
exprlist_to_fields(all_expr, &self.plan)?,
self.plan.schema().metadata().clone(),
)?;
- Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
- input: Arc::new(self.plan.clone()),
+ Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(self.plan.clone()),
group_expr,
aggr_expr,
- schema: DFSchemaRef::new(aggr_schema),
- })))
+ DFSchemaRef::new(aggr_schema),
+ )?)))
}
/// Create an expression to represent the explanation of the plan
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 2d5eb4680..cec55bfc1 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -17,7 +17,7 @@
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
-use crate::utils::exprlist_to_fields;
+use crate::utils::{exprlist_to_fields, grouping_set_expr_count};
use crate::{Expr, TableProviderFilterPushDown, TableSource};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError};
@@ -1314,6 +1314,34 @@ pub struct Aggregate {
}
impl Aggregate {
+ pub fn try_new(
+ input: Arc<LogicalPlan>,
+ group_expr: Vec<Expr>,
+ aggr_expr: Vec<Expr>,
+ schema: DFSchemaRef,
+ ) -> datafusion_common::Result<Self> {
+ if group_expr.is_empty() && aggr_expr.is_empty() {
+ return Err(DataFusionError::Plan(
+ "Aggregate requires at least one grouping or aggregate expression"
+ .to_string(),
+ ));
+ }
+ let group_expr_count = grouping_set_expr_count(&group_expr)?;
+ if schema.fields().len() != group_expr_count + aggr_expr.len() {
+ return Err(DataFusionError::Plan(format!(
+ "Aggregate schema has wrong number of fields. Expected {} got {}",
+ group_expr_count + aggr_expr.len(),
+ schema.fields().len()
+ )));
+ }
+ Ok(Self {
+ input,
+ group_expr,
+ aggr_expr,
+ schema,
+ })
+ }
+
pub fn try_from_plan(plan: &LogicalPlan) -> datafusion_common::Result<&Aggregate> {
match plan {
LogicalPlan::Aggregate(it) => Ok(it),
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 7d3f78b8f..e748536d7 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
Ok(())
}
+/// Count the number of distinct exprs in a list of group by expressions. If the
+/// first element is a `GroupingSet` expression then it must be the only expr.
+pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
+ if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
+ if group_expr.len() > 1 {
+ return Err(DataFusionError::Plan(
+ "Invalid group by expressions, GroupingSet must be the only expression"
+ .to_string(),
+ ));
+ }
+ Ok(grouping_set.distinct_expr().len())
+ } else {
+ Ok(group_expr.len())
+ }
+}
+
/// Find all distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression then it must be the only expr.
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
@@ -395,12 +411,12 @@ pub fn from_plan(
})),
LogicalPlan::Aggregate(Aggregate {
group_expr, schema, ..
- }) => Ok(LogicalPlan::Aggregate(Aggregate {
- group_expr: expr[0..group_expr.len()].to_vec(),
- aggr_expr: expr[group_expr.len()..].to_vec(),
- input: Arc::new(inputs[0].clone()),
- schema: schema.clone(),
- })),
+ }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(inputs[0].clone()),
+ expr[0..group_expr.len()].to_vec(),
+ expr[group_expr.len()..].to_vec(),
+ schema.clone(),
+ )?)),
LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort {
expr: expr.to_vec(),
input: Arc::new(inputs[0].clone()),
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 305283d99..f015aeaa0 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -189,12 +189,12 @@ fn optimize(
let new_aggr_expr = new_expr.pop().unwrap();
let new_group_expr = new_expr.pop().unwrap();
- Ok(LogicalPlan::Aggregate(Aggregate {
- input: Arc::new(new_input),
- group_expr: new_group_expr,
- aggr_expr: new_aggr_expr,
- schema: schema.clone(),
- }))
+ Ok(LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(new_input),
+ new_group_expr,
+ new_aggr_expr,
+ schema.clone(),
+ )?))
}
LogicalPlan::Sort(Sort { expr, input }) => {
let arrays = to_arrays(expr, input, &mut expr_set)?;
diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs
index aa3cdfb42..80cc1044d 100644
--- a/datafusion/optimizer/src/projection_push_down.rs
+++ b/datafusion/optimizer/src/projection_push_down.rs
@@ -345,18 +345,18 @@ fn optimize_plan(
schema.metadata().clone(),
)?;
- Ok(LogicalPlan::Aggregate(Aggregate {
- group_expr: group_expr.clone(),
- aggr_expr: new_aggr_expr,
- input: Arc::new(optimize_plan(
+ Ok(LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(optimize_plan(
_optimizer,
input,
&new_required_columns,
true,
_optimizer_config,
)?),
- schema: DFSchemaRef::new(new_schema),
- }))
+ group_expr.clone(),
+ new_aggr_expr,
+ DFSchemaRef::new(new_schema),
+ )?))
}
// scans:
// * remove un-used columns from the scan projection
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 3244fac8d..e36706caa 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -100,12 +100,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
all_field,
input.schema().metadata().clone(),
)?;
- let grouped_agg = LogicalPlan::Aggregate(Aggregate {
- input: input.clone(),
- group_expr: all_group_args,
- aggr_expr: Vec::new(),
- schema: Arc::new(grouped_schema.clone()),
- });
+ let grouped_agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ input.clone(),
+ all_group_args,
+ Vec::new(),
+ Arc::new(grouped_schema.clone()),
+ )?);
let grouped_agg = optimize_children(&grouped_agg);
let final_agg_schema = Arc::new(DFSchema::new_with_metadata(
base_group_expr
@@ -129,13 +129,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
));
});
- let final_agg = LogicalPlan::Aggregate(Aggregate {
- input: Arc::new(grouped_agg?),
- group_expr: group_expr.clone(),
- aggr_expr: new_aggr_expr,
- schema: final_agg_schema,
- });
-
+ let final_agg = LogicalPlan::Aggregate(Aggregate::try_new(
+ Arc::new(grouped_agg?),
+ group_expr.clone(),
+ new_aggr_expr,
+ final_agg_schema,
+ )?);
Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
alias_expr,
Arc::new(final_agg),