You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/05/09 12:57:16 UTC
[arrow-datafusion] branch master updated: Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 1fe038fbc Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)
1fe038fbc is described below
commit 1fe038fbc93286bd16c162d5d5d7d34dff1199dc
Author: Andy Grove <ag...@apache.org>
AuthorDate: Mon May 9 06:57:12 2022 -0600
Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)
* Add SQL planner support for ROLLUP and CUBE grouping sets
* prep for review
* fix more todo comments
* code cleanup
* clippy
* fmt and clippy
* revert change
* clippy
---
datafusion/core/src/datasource/listing/helpers.rs | 1 +
datafusion/core/src/logical_plan/builder.rs | 15 ++--
datafusion/core/src/logical_plan/expr.rs | 31 +++++++-
datafusion/core/src/logical_plan/expr_rewriter.rs | 17 ++++
datafusion/core/src/logical_plan/expr_visitor.rs | 14 ++++
.../core/src/optimizer/common_subexpr_eliminate.rs | 28 +++++++
.../core/src/optimizer/projection_push_down.rs | 2 +-
.../core/src/optimizer/simplify_expressions.rs | 1 +
datafusion/core/src/optimizer/utils.rs | 20 +++++
datafusion/core/src/physical_plan/planner.rs | 32 ++++++++
datafusion/core/src/sql/planner.rs | 63 +++++++++++++--
datafusion/core/src/sql/utils.rs | 84 ++++++++++++++++----
datafusion/expr/src/expr.rs | 92 ++++++++++++++++++++++
datafusion/expr/src/expr_schema.rs | 9 +++
14 files changed, 377 insertions(+), 32 deletions(-)
diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 9518986a1..11a91f2ee 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -96,6 +96,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
| Expr::InSubquery { .. }
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
+ | Expr::GroupingSet(_)
| Expr::Case { .. } => Recursion::Continue(self),
Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()),
diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs
index 8a0ea6d66..80ebfadd1 100644
--- a/datafusion/core/src/logical_plan/builder.rs
+++ b/datafusion/core/src/logical_plan/builder.rs
@@ -43,7 +43,8 @@ use std::{
sync::Arc,
};
-use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
+use super::{Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
+use crate::logical_plan::expr::exprlist_to_fields;
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, provider_as_source,
rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, DFSchema, DFSchemaRef, Limit,
@@ -557,7 +558,7 @@ impl LogicalPlanBuilder {
expr.extend(missing_exprs);
let new_schema = DFSchema::new_with_metadata(
- exprlist_to_fields(&expr, input_schema)?,
+ exprlist_to_fields(&expr, &input)?,
input_schema.metadata().clone(),
)?;
@@ -629,7 +630,7 @@ impl LogicalPlanBuilder {
.map(|f| Expr::Column(f.qualified_column()))
.collect();
let new_schema = DFSchema::new_with_metadata(
- exprlist_to_fields(&new_expr, schema)?,
+ exprlist_to_fields(&new_expr, &self.plan)?,
schema.metadata().clone(),
)?;
@@ -843,8 +844,7 @@ impl LogicalPlanBuilder {
let window_expr = normalize_cols(window_expr, &self.plan)?;
let all_expr = window_expr.iter();
validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
- let mut window_fields: Vec<DFField> =
- exprlist_to_fields(all_expr, self.plan.schema())?;
+ let mut window_fields: Vec<DFField> = exprlist_to_fields(all_expr, &self.plan)?;
window_fields.extend_from_slice(self.plan.schema().fields());
Ok(Self::from(LogicalPlan::Window(Window {
input: Arc::new(self.plan.clone()),
@@ -869,7 +869,7 @@ impl LogicalPlanBuilder {
let all_expr = group_expr.iter().chain(aggr_expr.iter());
validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
let aggr_schema = DFSchema::new_with_metadata(
- exprlist_to_fields(all_expr, self.plan.schema())?,
+ exprlist_to_fields(all_expr, &self.plan)?,
self.plan.schema().metadata().clone(),
)?;
Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
@@ -1126,13 +1126,14 @@ pub fn project_with_alias(
}
validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
let input_schema = DFSchema::new_with_metadata(
- exprlist_to_fields(&projected_expr, input_schema)?,
+ exprlist_to_fields(&projected_expr, &plan)?,
plan.schema().metadata().clone(),
)?;
let schema = match alias {
Some(ref alias) => input_schema.replace_qualifier(alias.as_str()),
None => input_schema,
};
+
Ok(LogicalPlan::Projection(Projection {
expr: projected_expr,
input: Arc::new(plan.clone()),
diff --git a/datafusion/core/src/logical_plan/expr.rs b/datafusion/core/src/logical_plan/expr.rs
index 673345c69..3ffc1894e 100644
--- a/datafusion/core/src/logical_plan/expr.rs
+++ b/datafusion/core/src/logical_plan/expr.rs
@@ -22,14 +22,15 @@ pub use super::Operator;
use crate::error::Result;
use crate::logical_plan::ExprSchemable;
use crate::logical_plan::{DFField, DFSchema};
+use crate::sql::utils::find_columns_referenced_by_expr;
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
-use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
+use datafusion_expr::{AccumulatorFunctionImplementation, LogicalPlan};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
@@ -138,9 +139,33 @@ pub fn create_udaf(
/// Create field meta-data from an expression, for use in a result set schema
+///
+/// When `plan` is a [`LogicalPlan::Aggregate`], a bare column that is referenced by
+/// one of the aggregate's group expressions is resolved against the schema of the
+/// aggregate's input; every other expression is resolved against the schema of
+/// `plan` itself.
+///
+/// # Errors
+///
+/// Returns the first error produced by `to_field` for any of the expressions.
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
- input_schema: &DFSchema,
+ plan: &LogicalPlan,
) -> Result<Vec<DFField>> {
- expr.into_iter().map(|e| e.to_field(input_schema)).collect()
+ match plan {
+ LogicalPlan::Aggregate(agg) => {
+ let group_columns: Vec<Column> = agg
+ .group_expr
+ .iter()
+ .flat_map(find_columns_referenced_by_expr)
+ .collect();
+ // map over the borrowed expressions directly; no owned copy is needed
+ expr.into_iter()
+ .map(|e| match e {
+ Expr::Column(c) if group_columns.contains(c) => {
+ // resolve against schema of input to aggregate
+ e.to_field(agg.input.schema())
+ }
+ _ => e.to_field(plan.schema()),
+ })
+ .collect()
+ }
+ _ => {
+ let input_schema = &plan.schema();
+ expr.into_iter().map(|e| e.to_field(input_schema)).collect()
+ }
+ }
}
/// Calls a named built in function
diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs
index 4e9476899..1f24556ea 100644
--- a/datafusion/core/src/logical_plan/expr_rewriter.rs
+++ b/datafusion/core/src/logical_plan/expr_rewriter.rs
@@ -24,6 +24,7 @@ use crate::logical_plan::ExprSchemable;
use crate::logical_plan::LogicalPlan;
use datafusion_common::Column;
use datafusion_common::Result;
+use datafusion_expr::expr::GroupingSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
@@ -215,6 +216,22 @@ impl ExprRewritable for Expr {
fun,
distinct,
},
+ Expr::GroupingSet(grouping_set) => match grouping_set {
+ GroupingSet::Rollup(exprs) => {
+ Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
+ }
+ GroupingSet::Cube(exprs) => {
+ Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?))
+ }
+ GroupingSet::GroupingSets(lists_of_exprs) => {
+ Expr::GroupingSet(GroupingSet::GroupingSets(
+ lists_of_exprs
+ .iter()
+ .map(|exprs| rewrite_vec(exprs.clone(), rewriter))
+ .collect::<Result<Vec<_>>>()?,
+ ))
+ }
+ },
Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
args: rewrite_vec(args, rewriter)?,
fun,
diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs
index 7c578da19..24acb65bc 100644
--- a/datafusion/core/src/logical_plan/expr_visitor.rs
+++ b/datafusion/core/src/logical_plan/expr_visitor.rs
@@ -19,6 +19,7 @@
use super::Expr;
use datafusion_common::Result;
+use datafusion_expr::expr::GroupingSet;
/// Controls how the visitor recursion should proceed.
pub enum Recursion<V: ExpressionVisitor> {
@@ -103,6 +104,19 @@ impl ExprVisitable for Expr {
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
+ Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs
+ .iter()
+ .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
+ Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
+ .iter()
+ .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
+ Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
+ lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| {
+ v.and_then(|v| {
+ exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v)))
+ })
+ })
+ }
Expr::Column(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs
index a9983cdf1..967ef58b3 100644
--- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs
+++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs
@@ -29,6 +29,7 @@ use crate::logical_plan::{
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::DataType;
+use datafusion_expr::expr::GroupingSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
@@ -482,6 +483,33 @@ impl ExprIdentifierVisitor<'_> {
desc.push_str("GetIndexedField-");
desc.push_str(&key.to_string());
}
+ Expr::GroupingSet(grouping_set) => match grouping_set {
+ GroupingSet::Rollup(exprs) => {
+ desc.push_str("Rollup");
+ for expr in exprs {
+ desc.push('-');
+ desc.push_str(&Self::desc_expr(expr));
+ }
+ }
+ GroupingSet::Cube(exprs) => {
+ desc.push_str("Cube");
+ for expr in exprs {
+ desc.push('-');
+ desc.push_str(&Self::desc_expr(expr));
+ }
+ }
+ GroupingSet::GroupingSets(lists_of_exprs) => {
+ desc.push_str("GroupingSets");
+ for exprs in lists_of_exprs {
+ desc.push('(');
+ for expr in exprs {
+ desc.push('-');
+ desc.push_str(&Self::desc_expr(expr));
+ }
+ desc.push(')');
+ }
+ }
+ },
}
desc
diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs
index 5062082e8..0979d8f5b 100644
--- a/datafusion/core/src/optimizer/projection_push_down.rs
+++ b/datafusion/core/src/optimizer/projection_push_down.rs
@@ -810,7 +810,7 @@ mod tests {
// that the Column references are unqualified (e.g. their
// relation is `None`). PlanBuilder resolves the expressions
let expr = vec![col("a"), col("b")];
- let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap();
+ let projected_fields = exprlist_to_fields(&expr, &table_scan).unwrap();
let projected_schema = DFSchema::new_with_metadata(
projected_fields,
input_schema.metadata().clone(),
diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs
index 4dfbb6eb6..e9694ebc5 100644
--- a/datafusion/core/src/optimizer/simplify_expressions.rs
+++ b/datafusion/core/src/optimizer/simplify_expressions.rs
@@ -380,6 +380,7 @@ impl<'a> ConstEvaluator<'a> {
| Expr::ScalarSubquery(_)
| Expr::WindowFunction { .. }
| Expr::Sort { .. }
+ | Expr::GroupingSet(_)
| Expr::Wildcard
| Expr::QualifiedWildcard { .. } => false,
Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs
index 48855df9f..2c56b5f89 100644
--- a/datafusion/core/src/optimizer/utils.rs
+++ b/datafusion/core/src/optimizer/utils.rs
@@ -36,6 +36,7 @@ use crate::{
logical_plan::ExpressionVisitor,
};
use datafusion_common::DFSchema;
+use datafusion_expr::expr::GroupingSet;
use std::{collections::HashSet, sync::Arc};
const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
@@ -83,6 +84,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
+ | Expr::GroupingSet(_)
| Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
@@ -323,6 +325,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => Ok(args.clone()),
+ Expr::GroupingSet(grouping_set) => match grouping_set {
+ GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
+ GroupingSet::Cube(exprs) => Ok(exprs.clone()),
+ GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
+ "GroupingSets are not supported yet".to_string(),
+ )),
+ },
Expr::WindowFunction {
args,
partition_by,
@@ -458,6 +467,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
+ Expr::GroupingSet(grouping_set) => match grouping_set {
+ GroupingSet::Rollup(_exprs) => {
+ Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
+ }
+ GroupingSet::Cube(_exprs) => {
+ // rebuild the same variant that was matched: a CUBE must stay a CUBE
+ Ok(Expr::GroupingSet(GroupingSet::Cube(expressions.to_vec())))
+ }
+ GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
+ "GroupingSets are not supported yet".to_string(),
+ )),
+ },
Expr::Case { .. } => {
let mut base_expr: Option<Box<Expr>> = None;
let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 85fb7d424..f6b3842f2 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -62,6 +62,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::{compute::can_cast_types, datatypes::DataType};
use async_trait::async_trait;
+use datafusion_expr::expr::GroupingSet;
use datafusion_physical_expr::expressions::DateIntervalExpr;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
@@ -174,6 +175,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
Ok(format!("{}({})", fun.name, names.join(",")))
}
+ Expr::GroupingSet(grouping_set) => match grouping_set {
+ GroupingSet::Rollup(exprs) => Ok(format!(
+ "ROLLUP ({})",
+ exprs
+ .iter()
+ .map(|e| create_physical_name(e, false))
+ .collect::<Result<Vec<_>>>()?
+ .join(", ")
+ )),
+ GroupingSet::Cube(exprs) => Ok(format!(
+ "CUBE ({})",
+ exprs
+ .iter()
+ .map(|e| create_physical_name(e, false))
+ .collect::<Result<Vec<_>>>()?
+ .join(", ")
+ )),
+ GroupingSet::GroupingSets(lists_of_exprs) => {
+ let mut strings = vec![];
+ for exprs in lists_of_exprs {
+ let exprs_str = exprs
+ .iter()
+ .map(|e| create_physical_name(e, false))
+ .collect::<Result<Vec<_>>>()?
+ .join(", ");
+ strings.push(format!("({})", exprs_str));
+ }
+ Ok(format!("GROUPING SETS ({})", strings.join(", ")))
+ }
+ },
+
Expr::InList {
expr,
list,
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index 33391d91e..af8329018 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -50,6 +50,7 @@ use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction};
use hashbrown::HashMap;
use datafusion_common::field_not_found;
+use datafusion_expr::expr::GroupingSet;
use datafusion_expr::logical_plan::{Filter, Subquery};
use sqlparser::ast::{
BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, FunctionArg,
@@ -1156,11 +1157,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// combine the original grouping and aggregate expressions into one list (note that
// we do not add the "having" expression since that is not part of the projection)
- let aggr_projection_exprs = group_by_exprs
- .iter()
- .chain(aggr_exprs.iter())
- .cloned()
- .collect::<Vec<Expr>>();
+ let mut aggr_projection_exprs = vec![];
+ for expr in &group_by_exprs {
+ match expr {
+ Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
+ aggr_projection_exprs.extend_from_slice(exprs)
+ }
+ Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
+ aggr_projection_exprs.extend_from_slice(exprs)
+ }
+ Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
+ for exprs in lists_of_exprs {
+ aggr_projection_exprs.extend_from_slice(exprs)
+ }
+ }
+ _ => aggr_projection_exprs.push(expr.clone()),
+ }
+ }
+ aggr_projection_exprs.extend_from_slice(&aggr_exprs);
// now attempt to resolve columns and replace with fully-qualified columns
let aggr_projection_exprs = aggr_projection_exprs
@@ -1885,10 +1899,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
normalize_ident(&function.name.0[0])
};
- // first, scalar built-in
- if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
+ // first, check SQL reserved words
+ if name == "rollup" {
+ let args = self.function_args_to_expr(function.args, schema)?;
+ return Ok(Expr::GroupingSet(GroupingSet::Rollup(args)));
+ } else if name == "cube" {
let args = self.function_args_to_expr(function.args, schema)?;
+ return Ok(Expr::GroupingSet(GroupingSet::Cube(args)));
+ }
+ // next, scalar built-in
+ if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
+ let args = self.function_args_to_expr(function.args, schema)?;
return Ok(Expr::ScalarFunction { fun, args });
};
@@ -4654,6 +4676,33 @@ mod tests {
quick_test(sql, &expected)
}
+ /// `ROLLUP` in `GROUP BY` is expected to stay a single group-by entry of the
+ /// `Aggregate` node, while the projection lists the individual columns.
+ #[tokio::test]
+ async fn aggregate_with_rollup() {
+ let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)";
+ let expected = "Projection: #person.id, #person.state, #person.age, #COUNT(UInt8(1))\
+ \n Aggregate: groupBy=[[#person.id, ROLLUP (#person.state, #person.age)]], aggr=[[COUNT(UInt8(1))]]\
+ \n TableScan: person projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// `CUBE` in `GROUP BY` is expected to stay a single group-by entry of the
+ /// `Aggregate` node, while the projection lists the individual columns.
+ #[tokio::test]
+ async fn aggregate_with_cube() {
+ let sql =
+ "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)";
+ let expected = "Projection: #person.id, #person.state, #person.age, #COUNT(UInt8(1))\
+ \n Aggregate: groupBy=[[#person.id, CUBE (#person.state, #person.age)]], aggr=[[COUNT(UInt8(1))]]\
+ \n TableScan: person projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// Placeholder for `GROUPING SETS` planning: the expected plan is still "TBD",
+ /// so the test is ignored (see the linked issue).
+ #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2469
+ #[tokio::test]
+ async fn aggregate_with_grouping_sets() {
+ let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))";
+ let expected = "TBD";
+ quick_test(sql, expected);
+ }
+
fn assert_field_not_found(err: DataFusionError, name: &str) {
match err {
DataFusionError::SchemaError { .. } => {
diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs
index 0293e2410..b2cf1f698 100644
--- a/datafusion/core/src/sql/utils.rs
+++ b/datafusion/core/src/sql/utils.rs
@@ -27,6 +27,7 @@ use crate::{
error::{DataFusionError, Result},
logical_plan::{Column, ExpressionVisitor, Recursion},
};
+use datafusion_expr::expr::GroupingSet;
use std::collections::HashMap;
/// Collect all deeply nested `Expr::AggregateFunction` and
@@ -100,7 +101,7 @@ impl ExpressionVisitor for ColumnCollector {
}
}
-fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
+pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
// As the `ExpressionVisitor` impl above always returns Ok, this
// "can't" error
let ColumnCollector { exprs } = e
@@ -235,22 +236,49 @@ pub(crate) fn check_columns_satisfy_exprs(
"Expr::Column are required".to_string(),
)),
})?;
-
- for e in &find_column_exprs(exprs) {
- if !columns.contains(e) {
- return Err(DataFusionError::Plan(format!(
- "{}: Expression {:?} could not be resolved from available columns: {}",
- message_prefix,
- e,
- columns
- .iter()
- .map(|e| format!("{}", e))
- .collect::<Vec<String>>()
- .join(", ")
- )));
+ let column_exprs = find_column_exprs(exprs);
+ for e in &column_exprs {
+ match e {
+ Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
+ for e in exprs {
+ check_column_satisfies_expr(columns, e, message_prefix)?;
+ }
+ }
+ Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
+ for e in exprs {
+ check_column_satisfies_expr(columns, e, message_prefix)?;
+ }
+ }
+ Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
+ for exprs in lists_of_exprs {
+ for e in exprs {
+ check_column_satisfies_expr(columns, e, message_prefix)?;
+ }
+ }
+ }
+ _ => check_column_satisfies_expr(columns, e, message_prefix)?,
}
}
+ Ok(())
+}
+/// Ensure `expr` is one of the available `columns`.
+///
+/// # Errors
+///
+/// Returns a `DataFusionError::Plan` error, prefixed with `message_prefix` and
+/// listing the available columns, when `expr` is not equal to any element of
+/// `columns`.
+fn check_column_satisfies_expr(
+ columns: &[Expr],
+ expr: &Expr,
+ message_prefix: &str,
+) -> Result<()> {
+ let is_available = columns.iter().any(|candidate| candidate == expr);
+ if !is_available {
+ // render the available columns with their `Display` form for the message
+ let available: Vec<String> = columns.iter().map(|c| c.to_string()).collect();
+ return Err(DataFusionError::Plan(format!(
+ "{}: Expression {:?} could not be resolved from available columns: {}",
+ message_prefix,
+ expr,
+ available.join(", ")
+ )));
+ }
Ok(())
}
@@ -456,6 +484,34 @@ where
expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?),
key: key.clone(),
}),
+ Expr::GroupingSet(set) => match set {
+ GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup(
+ exprs
+ .iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<Expr>>>()?,
+ ))),
+ GroupingSet::Cube(exprs) => Ok(Expr::GroupingSet(GroupingSet::Cube(
+ exprs
+ .iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<Expr>>>()?,
+ ))),
+ GroupingSet::GroupingSets(lists_of_exprs) => {
+ let mut new_lists_of_exprs = vec![];
+ for exprs in lists_of_exprs {
+ new_lists_of_exprs.push(
+ exprs
+ .iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<Expr>>>()?,
+ );
+ }
+ Ok(Expr::GroupingSet(GroupingSet::GroupingSets(
+ new_lists_of_exprs,
+ )))
+ }
+ },
},
}
}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 4d88ed815..c1c61d1ff 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -249,6 +249,24 @@ pub enum Expr {
Wildcard,
/// Represents a reference to all fields in a specific schema.
QualifiedWildcard { qualifier: String },
+ /// List of grouping set expressions. Only valid in the context of an aggregate
+ /// GROUP BY expression list
+ GroupingSet(GroupingSet),
+}
+
+/// Grouping sets
+///
+/// See <https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS>
+/// for Postgres definition.
+/// See <https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-groupby.html>
+/// for Apache Spark definition.
+#[derive(Clone, PartialEq, Hash)]
+pub enum GroupingSet {
+ /// Rollup grouping sets, e.g. `ROLLUP (c0, c1, c2)`
+ Rollup(Vec<Expr>),
+ /// Cube grouping sets, e.g. `CUBE (c0, c1, c2)`
+ Cube(Vec<Expr>),
+ /// User-defined grouping sets, e.g. `GROUPING SETS ((c0), (c1, c2), (c3, c4))`;
+ /// each inner list is one grouping set
+ GroupingSets(Vec<Vec<Expr>>),
}
/// Fixed seed for the hashing so that Ords are consistent across runs
@@ -556,6 +574,51 @@ impl fmt::Debug for Expr {
Expr::GetIndexedField { ref expr, key } => {
write!(f, "({:?})[{}]", expr, key)
}
+ Expr::GroupingSet(grouping_sets) => match grouping_sets {
+ GroupingSet::Rollup(exprs) => {
+ // ROLLUP (c0, c1, c2)
+ write!(
+ f,
+ "ROLLUP ({})",
+ exprs
+ .iter()
+ .map(|e| format!("{}", e))
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ GroupingSet::Cube(exprs) => {
+ // CUBE (c0, c1, c2)
+ write!(
+ f,
+ "CUBE ({})",
+ exprs
+ .iter()
+ .map(|e| format!("{}", e))
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ GroupingSet::GroupingSets(lists_of_exprs) => {
+ // GROUPING SETS ((c0), (c1, c2), (c3, c4))
+ write!(
+ f,
+ "GROUPING SETS ({})",
+ lists_of_exprs
+ .iter()
+ .map(|exprs| format!(
+ "({})",
+ exprs
+ .iter()
+ .map(|e| format!("{}", e))
+ .collect::<Vec<String>>()
+ .join(", ")
+ ))
+ .collect::<Vec<String>>()
+ .join(", ")
+ )
+ }
+ },
}
}
}
@@ -710,6 +773,26 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
}
Ok(format!("{}({})", fun.name, names.join(",")))
}
+ Expr::GroupingSet(grouping_set) => match grouping_set {
+ GroupingSet::Rollup(exprs) => Ok(format!(
+ "ROLLUP ({})",
+ create_names(exprs.as_slice(), input_schema)?
+ )),
+ GroupingSet::Cube(exprs) => Ok(format!(
+ "CUBE ({})",
+ create_names(exprs.as_slice(), input_schema)?
+ )),
+ GroupingSet::GroupingSets(lists_of_exprs) => {
+ let mut list_of_names = vec![];
+ for exprs in lists_of_exprs {
+ list_of_names.push(format!(
+ "({})",
+ create_names(exprs.as_slice(), input_schema)?
+ ));
+ }
+ Ok(format!("GROUPING SETS ({})", list_of_names.join(", ")))
+ }
+ },
Expr::InList {
expr,
list,
@@ -750,6 +833,15 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
}
}
+/// Build the names of `exprs`, in order, and join them with `", "`.
+///
+/// # Errors
+///
+/// Propagates the first error returned by `create_name`.
+fn create_names(exprs: &[Expr], input_schema: &DFSchema) -> Result<String> {
+ let mut names = Vec::with_capacity(exprs.len());
+ for e in exprs {
+ names.push(create_name(e, input_schema)?);
+ }
+ Ok(names.join(", "))
+}
+
#[cfg(test)]
mod test {
use crate::expr_fn::col;
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index b932eefa0..2433024e3 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -124,6 +124,10 @@ impl ExprSchemable for Expr {
"QualifiedWildcard expressions are not valid in a logical query plan"
.to_owned(),
)),
+ Expr::GroupingSet(_) => {
+ // grouping sets do not really have a type and do not appear in projections
+ Ok(DataType::Null)
+ }
Expr::GetIndexedField { ref expr, key } => {
let data_type = expr.get_type(schema)?;
@@ -198,6 +202,11 @@ impl ExprSchemable for Expr {
let data_type = expr.get_type(input_schema)?;
get_indexed_field(&data_type, key).map(|x| x.is_nullable())
}
+ Expr::GroupingSet(_) => {
+ // grouping sets do not really have the concept of nullable and do not appear
+ // in projections
+ Ok(true)
+ }
}
}