You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/06/13 19:55:32 UTC
[arrow-datafusion] branch master updated: Support for GROUPING SETS/CUBE/ROLLUP (#2716)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new ca5339bfd Support for GROUPING SETS/CUBE/ROLLUP (#2716)
ca5339bfd is described below
commit ca5339bfd27677c165263aa01f263c8fba886a45
Author: Dan Harris <13...@users.noreply.github.com>
AuthorDate: Mon Jun 13 15:55:28 2022 -0400
Support for GROUPING SETS/CUBE/ROLLUP (#2716)
* WIP
* Implement for non-row based accumulators
* Non-row aggregations
* Map logical plan and add some basic tests
* Handle grouping sets in various optimize passes.
* Implemented create_cube_expr and create_rollup_expr functions
* Cleanup and ignore SingleDistinctToGroupBy when using grouping sets for now
* Handle grouping sets in SingleDistinctToGroupBy
* Add more tests and burn the boats
* Fix(ish) partitioning
* Serialization for grouping set exprs
* fixed bug with create_cube_expr function
* fixed bug with create_cube_expr function
* Fixed bug in row-based-aggregation
* Added unit tests for test_create_rollup_expr and test_create_cube_expr
* Formatting
* Tests, linter fixes and docs
* Linting
* Better encoding which avoids evaluating grouping expressions redundantly
* Remove commented code
* Apply suggestions from code review
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* PR Comments: Rename PhysicalGroupingSet -> PhysicalGroupBy and clarify doc comment
* Disable single_distinct_to_groupby for grouping sets for now and add unit tests for single distinct queries.
* PR comments
* Remove old comment
* Return PhysicalGroupBy from AggregateExec::group_expr
Co-authored-by: Ryan Tomczik <ry...@coralogix.com>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
.../src/physical_optimizer/aggregate_statistics.rs | 26 +-
.../core/src/physical_optimizer/repartition.rs | 8 +-
.../core/src/physical_plan/aggregates/hash.rs | 276 +++++------
.../core/src/physical_plan/aggregates/mod.rs | 393 ++++++++++++++--
.../core/src/physical_plan/aggregates/row_hash.rs | 288 ++++++------
datafusion/core/src/physical_plan/planner.rs | 398 +++++++++++++++-
datafusion/core/tests/dataframe.rs | 219 ++++++++-
datafusion/core/tests/sql/aggregates.rs | 519 +++++++++++++++++++++
datafusion/expr/src/expr.rs | 21 +
datafusion/expr/src/expr_fn.rs | 16 +
datafusion/expr/src/logical_plan/builder.rs | 9 +-
datafusion/expr/src/utils.rs | 16 +
datafusion/optimizer/src/projection_push_down.rs | 6 +-
.../optimizer/src/single_distinct_to_groupby.rs | 101 +++-
datafusion/proto/proto/datafusion.proto | 24 +
datafusion/proto/src/from_proto.rs | 33 +-
datafusion/proto/src/lib.rs | 29 ++
datafusion/proto/src/to_proto.rs | 43 +-
18 files changed, 2044 insertions(+), 381 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index bcf4fec07..cafd61d9e 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -265,7 +265,7 @@ mod tests {
use crate::error::Result;
use crate::logical_plan::Operator;
- use crate::physical_plan::aggregates::AggregateExec;
+ use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::common;
use crate::physical_plan::expressions::Count;
@@ -407,7 +407,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
source,
Arc::clone(&schema),
@@ -415,7 +415,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
@@ -435,7 +435,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
source,
Arc::clone(&schema),
@@ -443,7 +443,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
@@ -462,7 +462,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
source,
Arc::clone(&schema),
@@ -473,7 +473,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
Arc::new(coalesce),
Arc::clone(&schema),
@@ -492,7 +492,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
source,
Arc::clone(&schema),
@@ -503,7 +503,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
Arc::new(coalesce),
Arc::clone(&schema),
@@ -533,7 +533,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
filter,
Arc::clone(&schema),
@@ -541,7 +541,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
@@ -576,7 +576,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
filter,
Arc::clone(&schema),
@@ -584,7 +584,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs
index b3b7ba948..e9e14abf6 100644
--- a/datafusion/core/src/physical_optimizer/repartition.rs
+++ b/datafusion/core/src/physical_optimizer/repartition.rs
@@ -242,7 +242,9 @@ mod tests {
use super::*;
use crate::datasource::listing::PartitionedFile;
use crate::datasource::object_store::ObjectStoreUrl;
- use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+ use crate::physical_plan::aggregates::{
+ AggregateExec, AggregateMode, PhysicalGroupBy,
+ };
use crate::physical_plan::expressions::{col, PhysicalSortExpr};
use crate::physical_plan::file_format::{FileScanConfig, ParquetExec};
use crate::physical_plan::filter::FilterExec;
@@ -305,12 +307,12 @@ mod tests {
Arc::new(
AggregateExec::try_new(
AggregateMode::Final,
- vec![],
+ PhysicalGroupBy::default(),
vec![],
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
- vec![],
+ PhysicalGroupBy::default(),
vec![],
input,
schema.clone(),
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index 45719260c..ddf9af18f 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -29,7 +29,7 @@ use futures::{
use crate::error::Result;
use crate::physical_plan::aggregates::{
- evaluate, evaluate_many, AccumulatorItem, AggregateMode,
+ evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy,
};
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
@@ -81,7 +81,7 @@ pub(crate) struct GroupedHashAggregateStream {
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
- group_expr: Vec<Arc<dyn PhysicalExpr>>,
+ group_by: PhysicalGroupBy,
baseline_metrics: BaselineMetrics,
random_state: RandomState,
@@ -93,7 +93,7 @@ impl GroupedHashAggregateStream {
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
- group_expr: Vec<Arc<dyn PhysicalExpr>>,
+ group_by: PhysicalGroupBy,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
@@ -104,7 +104,7 @@ impl GroupedHashAggregateStream {
// Assume create_schema() always put group columns in front of aggr columns, we set
// col_idx_base to group expression count.
let aggregate_expressions =
- aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?;
+ aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?;
timer.done();
@@ -113,7 +113,7 @@ impl GroupedHashAggregateStream {
mode,
input,
aggr_expr,
- group_expr,
+ group_by,
baseline_metrics,
aggregate_expressions,
accumulators: Default::default(),
@@ -144,7 +144,7 @@ impl Stream for GroupedHashAggregateStream {
let result = group_aggregate_batch(
&this.mode,
&this.random_state,
- &this.group_expr,
+ &this.group_by,
&this.aggr_expr,
batch,
&mut this.accumulators,
@@ -165,7 +165,7 @@ impl Stream for GroupedHashAggregateStream {
let result = create_batch_from_map(
&this.mode,
&this.accumulators,
- this.group_expr.len(),
+ this.group_by.expr.len(),
&this.schema,
)
.record_output(&this.baseline_metrics);
@@ -191,152 +191,154 @@ impl RecordBatchStream for GroupedHashAggregateStream {
fn group_aggregate_batch(
mode: &AggregateMode,
random_state: &RandomState,
- group_expr: &[Arc<dyn PhysicalExpr>],
+ group_by: &PhysicalGroupBy,
aggr_expr: &[Arc<dyn AggregateExpr>],
batch: RecordBatch,
accumulators: &mut Accumulators,
aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<()> {
// evaluate the grouping expressions
- let group_values = evaluate(group_expr, &batch)?;
+ let group_by_values = evaluate_group_by(group_by, &batch)?;
// evaluate the aggregation expressions.
// We could evaluate them after the `take`, but since we need to evaluate all
// of them anyways, it is more performant to do it while they are together.
let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
- // 1.1 construct the key from the group values
- // 1.2 construct the mapping key if it does not exist
- // 1.3 add the row' index to `indices`
-
- // track which entries in `accumulators` have rows in this batch to aggregate
- let mut groups_with_rows = vec![];
-
- // 1.1 Calculate the group keys for the group values
- let mut batch_hashes = vec![0; batch.num_rows()];
- create_hashes(&group_values, random_state, &mut batch_hashes)?;
-
- for (row, hash) in batch_hashes.into_iter().enumerate() {
- let Accumulators { map, group_states } = accumulators;
-
- let entry = map.get_mut(hash, |(_hash, group_idx)| {
- // verify that a group that we are inserting with hash is
- // actually the same key value as the group in
- // existing_idx (aka group_values @ row)
- let group_state = &group_states[*group_idx];
- group_values
- .iter()
- .zip(group_state.group_by_values.iter())
- .all(|(array, scalar)| scalar.eq_array(array, row))
- });
-
- match entry {
- // Existing entry for this group value
- Some((_hash, group_idx)) => {
- let group_state = &mut group_states[*group_idx];
- // 1.3
- if group_state.indices.is_empty() {
- groups_with_rows.push(*group_idx);
- };
- group_state.indices.push(row as u32); // remember this row
- }
- // 1.2 Need to create new entry
- None => {
- let accumulator_set = aggregates::create_accumulators(aggr_expr)?;
+ for grouping_set_values in group_by_values {
+ // 1.1 construct the key from the group values
+ // 1.2 construct the mapping key if it does not exist
+ // 1.3 add the row's index to `indices`
+
+ // track which entries in `accumulators` have rows in this batch to aggregate
+ let mut groups_with_rows = vec![];
+
+ // 1.1 Calculate the group keys for the group values
+ let mut batch_hashes = vec![0; batch.num_rows()];
+ create_hashes(&grouping_set_values, random_state, &mut batch_hashes)?;
- // Copy group values out of arrays into `ScalarValue`s
- let group_by_values = group_values
+ for (row, hash) in batch_hashes.into_iter().enumerate() {
+ let Accumulators { map, group_states } = accumulators;
+
+ let entry = map.get_mut(hash, |(_hash, group_idx)| {
+ // verify that a group that we are inserting with hash is
+ // actually the same key value as the group in
+ // existing_idx (aka group_values @ row)
+ let group_state = &group_states[*group_idx];
+ grouping_set_values
.iter()
- .map(|col| ScalarValue::try_from_array(col, row))
- .collect::<Result<Vec<_>>>()?;
-
- // Add new entry to group_states and save newly created index
- let group_state = GroupState {
- group_by_values: group_by_values.into_boxed_slice(),
- accumulator_set,
- indices: vec![row as u32], // 1.3
- };
- let group_idx = group_states.len();
- group_states.push(group_state);
- groups_with_rows.push(group_idx);
-
- // for hasher function, use precomputed hash value
- map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
- }
- };
- }
+ .zip(group_state.group_by_values.iter())
+ .all(|(array, scalar)| scalar.eq_array(array, row))
+ });
+
+ match entry {
+ // Existing entry for this group value
+ Some((_hash, group_idx)) => {
+ let group_state = &mut group_states[*group_idx];
+ // 1.3
+ if group_state.indices.is_empty() {
+ groups_with_rows.push(*group_idx);
+ };
+ group_state.indices.push(row as u32); // remember this row
+ }
+ // 1.2 Need to create new entry
+ None => {
+ let accumulator_set = aggregates::create_accumulators(aggr_expr)?;
+
+ // Copy group values out of arrays into `ScalarValue`s
+ let group_by_values = grouping_set_values
+ .iter()
+ .map(|col| ScalarValue::try_from_array(col, row))
+ .collect::<Result<Vec<_>>>()?;
+
+ // Add new entry to group_states and save newly created index
+ let group_state = GroupState {
+ group_by_values: group_by_values.into_boxed_slice(),
+ accumulator_set,
+ indices: vec![row as u32], // 1.3
+ };
+ let group_idx = group_states.len();
+ group_states.push(group_state);
+ groups_with_rows.push(group_idx);
+
+ // for hasher function, use precomputed hash value
+ map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
+ }
+ };
+ }
- // Collect all indices + offsets based on keys in this vec
- let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
- let mut offsets = vec![0];
- let mut offset_so_far = 0;
- for group_idx in groups_with_rows.iter() {
- let indices = &accumulators.group_states[*group_idx].indices;
- batch_indices.append_slice(indices)?;
- offset_so_far += indices.len();
- offsets.push(offset_so_far);
- }
- let batch_indices = batch_indices.finish();
+ // Collect all indices + offsets based on keys in this vec
+ let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
+ let mut offsets = vec![0];
+ let mut offset_so_far = 0;
+ for group_idx in groups_with_rows.iter() {
+ let indices = &accumulators.group_states[*group_idx].indices;
+ batch_indices.append_slice(indices)?;
+ offset_so_far += indices.len();
+ offsets.push(offset_so_far);
+ }
+ let batch_indices = batch_indices.finish();
- // `Take` all values based on indices into Arrays
- let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
- .iter()
- .map(|array| {
- array
- .iter()
- .map(|array| {
- compute::take(
- array.as_ref(),
- &batch_indices,
- None, // None: no index check
- )
- .unwrap()
- })
- .collect()
- // 2.3
- })
- .collect();
-
- // 2.1 for each key in this batch
- // 2.2 for each aggregation
- // 2.3 `slice` from each of its arrays the keys' values
- // 2.4 update / merge the accumulator with the values
- // 2.5 clear indices
- groups_with_rows
- .iter()
- .zip(offsets.windows(2))
- .try_for_each(|(group_idx, offsets)| {
- let group_state = &mut accumulators.group_states[*group_idx];
- // 2.2
- group_state
- .accumulator_set
- .iter_mut()
- .zip(values.iter())
- .map(|(accumulator, aggr_array)| {
- (
- accumulator,
- aggr_array
- .iter()
- .map(|array| {
- // 2.3
- array.slice(offsets[0], offsets[1] - offsets[0])
- })
- .collect::<Vec<ArrayRef>>(),
- )
- })
- .try_for_each(|(accumulator, values)| match mode {
- AggregateMode::Partial => accumulator.update_batch(&values),
- AggregateMode::FinalPartitioned | AggregateMode::Final => {
- // note: the aggregation here is over states, not values, thus the merge
- accumulator.merge_batch(&values)
- }
- })
- // 2.5
- .and({
- group_state.indices.clear();
- Ok(())
- })
- })?;
+ // `Take` all values based on indices into Arrays
+ let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
+ .iter()
+ .map(|array| {
+ array
+ .iter()
+ .map(|array| {
+ compute::take(
+ array.as_ref(),
+ &batch_indices,
+ None, // None: no index check
+ )
+ .unwrap()
+ })
+ .collect()
+ // 2.3
+ })
+ .collect();
+
+ // 2.1 for each key in this batch
+ // 2.2 for each aggregation
+ // 2.3 `slice` from each of its arrays the keys' values
+ // 2.4 update / merge the accumulator with the values
+ // 2.5 clear indices
+ groups_with_rows
+ .iter()
+ .zip(offsets.windows(2))
+ .try_for_each(|(group_idx, offsets)| {
+ let group_state = &mut accumulators.group_states[*group_idx];
+ // 2.2
+ group_state
+ .accumulator_set
+ .iter_mut()
+ .zip(values.iter())
+ .map(|(accumulator, aggr_array)| {
+ (
+ accumulator,
+ aggr_array
+ .iter()
+ .map(|array| {
+ // 2.3
+ array.slice(offsets[0], offsets[1] - offsets[0])
+ })
+ .collect::<Vec<ArrayRef>>(),
+ )
+ })
+ .try_for_each(|(accumulator, values)| match mode {
+ AggregateMode::Partial => accumulator.update_batch(&values),
+ AggregateMode::FinalPartitioned | AggregateMode::Final => {
+ // note: the aggregation here is over states, not values, thus the merge
+ accumulator.merge_batch(&values)
+ }
+ })
+ // 2.5
+ .and({
+ group_state.indices.clear();
+ Ok(())
+ })
+ })?;
+ }
Ok(())
}
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index abe20cdcb..657b6281a 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -37,6 +37,7 @@ use datafusion_physical_expr::{
expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr,
};
use std::any::Any;
+
use std::sync::Arc;
mod hash;
@@ -65,13 +66,93 @@ pub enum AggregateMode {
FinalPartitioned,
}
+/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
+/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
+/// and a single group [false, false].
+/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
+/// into multiple groups, using null expressions to align each group.
+/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
+/// create a `PhysicalGroupBy` like
+/// PhysicalGroupBy {
+/// expr: [(col(a), a), (col(b), b)],
+/// null_expr: [(NULL, a), (NULL, b)],
+/// groups: [
+/// [false, false], // (a,b)
+/// [false, true], // (a) <=> (a, NULL)
+/// [true, false] // (b) <=> (NULL, b)
+/// ]
+/// }
+#[derive(Clone, Debug, Default)]
+pub struct PhysicalGroupBy {
+ /// Distinct (Physical Expr, Alias) in the grouping set
+ expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ /// Corresponding NULL expressions for expr
+ null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ /// Null mask for each group in this grouping set. Each group is
+ /// composed of either one of the group expressions in expr or a null
+ /// expression in null_expr. If groups[i][j] is true, then the
+ /// j-th expression in the i-th group is NULL, otherwise it is expr[j].
+ groups: Vec<Vec<bool>>,
+}
+
+impl PhysicalGroupBy {
+ /// Create a new `PhysicalGroupBy`
+ pub fn new(
+ expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ groups: Vec<Vec<bool>>,
+ ) -> Self {
+ Self {
+ expr,
+ null_expr,
+ groups,
+ }
+ }
+
+ /// Create a GROUPING SET with only a single group. This is the "standard"
+ /// case when building a plan from an expression such as `GROUP BY a,b,c`
+ pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
+ let num_exprs = expr.len();
+ Self {
+ expr,
+ null_expr: vec![],
+ groups: vec![vec![false; num_exprs]],
+ }
+ }
+
+ /// Returns true if this GROUP BY contains NULL expressions
+ pub fn contains_null(&self) -> bool {
+ self.groups.iter().flatten().any(|is_null| *is_null)
+ }
+
+ /// Returns the group expressions
+ pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
+ &self.expr
+ }
+
+ /// Returns the null expressions
+ pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
+ &self.null_expr
+ }
+
+ /// Returns the group null masks
+ pub fn groups(&self) -> &[Vec<bool>] {
+ &self.groups
+ }
+
+ /// Returns true if this `PhysicalGroupBy` has no group expressions
+ pub fn is_empty(&self) -> bool {
+ self.expr.is_empty()
+ }
+}
+
/// Hash aggregate execution plan
#[derive(Debug)]
pub struct AggregateExec {
/// Aggregation mode (full, partial)
mode: AggregateMode,
- /// Grouping expressions
- group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ /// Group by expressions
+ group_by: PhysicalGroupBy,
/// Aggregate expressions
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
/// Input plan, could be a partial aggregate or the input to the aggregate
@@ -90,18 +171,24 @@ impl AggregateExec {
/// Create a new hash aggregate execution plan
pub fn try_new(
mode: AggregateMode,
- group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ group_by: PhysicalGroupBy,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
- let schema = create_schema(&input.schema(), &group_expr, &aggr_expr, mode)?;
+ let schema = create_schema(
+ &input.schema(),
+ &group_by.expr,
+ &aggr_expr,
+ group_by.contains_null(),
+ mode,
+ )?;
let schema = Arc::new(schema);
Ok(AggregateExec {
mode,
- group_expr,
+ group_by,
aggr_expr,
input,
schema,
@@ -116,15 +203,16 @@ impl AggregateExec {
}
/// Grouping expressions
- pub fn group_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
- &self.group_expr
+ pub fn group_expr(&self) -> &PhysicalGroupBy {
+ &self.group_by
}
/// Grouping expressions as they occur in the output schema
pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
// Update column indices. Since the group by columns come first in the output schema, their
// indices are simply 0..self.group_expr(len).
- self.group_expr
+ self.group_by
+ .expr()
.iter()
.enumerate()
.map(|(index, (_col, name))| {
@@ -149,7 +237,7 @@ impl AggregateExec {
}
fn row_aggregate_supported(&self) -> bool {
- let group_schema = group_schema(&self.schema, self.group_expr.len());
+ let group_schema = group_schema(&self.schema, self.group_by.expr.len());
row_supported(&group_schema, RowType::Compact)
&& accumulator_v2_supported(&self.aggr_expr)
}
@@ -178,7 +266,7 @@ impl ExecutionPlan for AggregateExec {
match &self.mode {
AggregateMode::Partial => Distribution::UnspecifiedDistribution,
AggregateMode::FinalPartitioned => Distribution::HashPartitioned(
- self.group_expr.iter().map(|x| x.0.clone()).collect(),
+ self.group_by.expr.iter().map(|x| x.0.clone()).collect(),
),
AggregateMode::Final => Distribution::SinglePartition,
}
@@ -198,7 +286,7 @@ impl ExecutionPlan for AggregateExec {
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(AggregateExec::try_new(
self.mode,
- self.group_expr.clone(),
+ self.group_by.clone(),
self.aggr_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
@@ -211,11 +299,10 @@ impl ExecutionPlan for AggregateExec {
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let input = self.input.execute(partition, context)?;
- let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
- if self.group_expr.is_empty() {
+ if self.group_by.expr.is_empty() {
Ok(Box::pin(AggregateStream::new(
self.mode,
self.schema.clone(),
@@ -227,7 +314,7 @@ impl ExecutionPlan for AggregateExec {
Ok(Box::pin(GroupedHashAggregateStreamV2::new(
self.mode,
self.schema.clone(),
- group_expr,
+ self.group_by.clone(),
self.aggr_expr.clone(),
input,
baseline_metrics,
@@ -236,7 +323,7 @@ impl ExecutionPlan for AggregateExec {
Ok(Box::pin(GroupedHashAggregateStream::new(
self.mode,
self.schema.clone(),
- group_expr,
+ self.group_by.clone(),
self.aggr_expr.clone(),
input,
baseline_metrics,
@@ -256,18 +343,53 @@ impl ExecutionPlan for AggregateExec {
match t {
DisplayFormatType::Default => {
write!(f, "AggregateExec: mode={:?}", self.mode)?;
- let g: Vec<String> = self
- .group_expr
- .iter()
- .map(|(e, alias)| {
- let e = e.to_string();
- if &e != alias {
- format!("{} as {}", e, alias)
- } else {
- e
- }
- })
- .collect();
+ let g: Vec<String> = if self.group_by.groups.len() == 1 {
+ self.group_by
+ .expr
+ .iter()
+ .map(|(e, alias)| {
+ let e = e.to_string();
+ if &e != alias {
+ format!("{} as {}", e, alias)
+ } else {
+ e
+ }
+ })
+ .collect()
+ } else {
+ self.group_by
+ .groups
+ .iter()
+ .map(|group| {
+ let terms = group
+ .iter()
+ .enumerate()
+ .map(|(idx, is_null)| {
+ if *is_null {
+ let (e, alias) = &self.group_by.null_expr[idx];
+ let e = e.to_string();
+ if &e != alias {
+ format!("{} as {}", e, alias)
+ } else {
+ e
+ }
+ } else {
+ let (e, alias) = &self.group_by.expr[idx];
+ let e = e.to_string();
+ if &e != alias {
+ format!("{} as {}", e, alias)
+ } else {
+ e
+ }
+ }
+ })
+ .collect::<Vec<String>>()
+ .join(", ");
+ format!("({})", terms)
+ })
+ .collect()
+ };
+
write!(f, ", gby=[{}]", g.join(", "))?;
let a: Vec<String> = self
@@ -289,7 +411,7 @@ impl ExecutionPlan for AggregateExec {
// - aggregations sometimes also preserve invariants such as min, max...
match self.mode {
AggregateMode::Final | AggregateMode::FinalPartitioned
- if self.group_expr.is_empty() =>
+ if self.group_by.expr.is_empty() =>
{
Statistics {
num_rows: Some(1),
@@ -306,6 +428,7 @@ fn create_schema(
input_schema: &Schema,
group_expr: &[(Arc<dyn PhysicalExpr>, String)],
aggr_expr: &[Arc<dyn AggregateExpr>],
+ contains_null_expr: bool,
mode: AggregateMode,
) -> datafusion_common::Result<Schema> {
let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
@@ -313,7 +436,10 @@ fn create_schema(
fields.push(Field::new(
name,
expr.data_type(input_schema)?,
- expr.nullable(input_schema)?,
+ // In cases where we have multiple grouping sets, we will use NULL expressions in
+ // order to align the grouping sets. So the field must be nullable even if the underlying
+ // schema field is not.
+ contains_null_expr || expr.nullable(input_schema)?,
))
}
@@ -469,11 +595,54 @@ fn evaluate_many(
.collect::<Result<Vec<_>>>()
}
+fn evaluate_group_by(
+ group_by: &PhysicalGroupBy,
+ batch: &RecordBatch,
+) -> Result<Vec<Vec<ArrayRef>>> {
+ let exprs: Vec<ArrayRef> = group_by
+ .expr
+ .iter()
+ .map(|(expr, _)| {
+ let value = expr.evaluate(batch)?;
+ Ok(value.into_array(batch.num_rows()))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let null_exprs: Vec<ArrayRef> = group_by
+ .null_expr
+ .iter()
+ .map(|(expr, _)| {
+ let value = expr.evaluate(batch)?;
+ Ok(value.into_array(batch.num_rows()))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(group_by
+ .groups
+ .iter()
+ .map(|group| {
+ group
+ .iter()
+ .enumerate()
+ .map(|(idx, is_null)| {
+ if *is_null {
+ null_exprs[idx].clone()
+ } else {
+ exprs[idx].clone()
+ }
+ })
+ .collect()
+ })
+ .collect())
+}
+
#[cfg(test)]
mod tests {
use crate::execution::context::TaskContext;
use crate::from_slice::FromSlice;
- use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+ use crate::physical_plan::aggregates::{
+ AggregateExec, AggregateMode, PhysicalGroupBy,
+ };
use crate::physical_plan::expressions::{col, Avg};
use crate::test::assert_is_pending;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
@@ -482,7 +651,8 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
- use datafusion_common::{DataFusionError, Result};
+ use datafusion_common::{DataFusionError, Result, ScalarValue};
+ use datafusion_physical_expr::expressions::{lit, Count};
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
use futures::{FutureExt, Stream};
use std::any::Any;
@@ -528,12 +698,129 @@ mod tests {
)
}
+ async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+ let input_schema = input.schema();
+
+ let grouping_set = PhysicalGroupBy {
+ expr: vec![
+ (col("a", &input_schema)?, "a".to_string()),
+ (col("b", &input_schema)?, "b".to_string()),
+ ],
+ null_expr: vec![
+ (lit(ScalarValue::UInt32(None)), "a".to_string()),
+ (lit(ScalarValue::Float64(None)), "b".to_string()),
+ ],
+ groups: vec![
+ vec![false, true], // (a, NULL)
+ vec![true, false], // (NULL, b)
+ vec![false, false], // (a,b)
+ ],
+ };
+
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Count::new(
+ lit(ScalarValue::Int8(Some(1))),
+ "COUNT(1)".to_string(),
+ DataType::Int64,
+ ))];
+
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+
+ let partial_aggregate = Arc::new(AggregateExec::try_new(
+ AggregateMode::Partial,
+ grouping_set.clone(),
+ aggregates.clone(),
+ input,
+ input_schema.clone(),
+ )?);
+
+ let result =
+ common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
+
+ let expected = vec![
+ "+---+---+-----------------+",
+ "| a | b | COUNT(1)[count] |",
+ "+---+---+-----------------+",
+ "| | 1 | 2 |",
+ "| | 2 | 2 |",
+ "| | 3 | 2 |",
+ "| | 4 | 2 |",
+ "| 2 | | 2 |",
+ "| 2 | 1 | 2 |",
+ "| 3 | | 3 |",
+ "| 3 | 2 | 2 |",
+ "| 3 | 3 | 1 |",
+ "| 4 | | 3 |",
+ "| 4 | 3 | 1 |",
+ "| 4 | 4 | 2 |",
+ "+---+---+-----------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let groups = partial_aggregate.group_expr().expr().to_vec();
+
+ let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
+
+ let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = groups
+ .iter()
+ .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
+ .collect::<Result<_>>()?;
+
+ let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+
+ let merged_aggregate = Arc::new(AggregateExec::try_new(
+ AggregateMode::Final,
+ final_grouping_set,
+ aggregates,
+ merge,
+ input_schema,
+ )?);
+
+ let result =
+ common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?;
+ assert_eq!(result.len(), 1);
+
+ let batch = &result[0];
+ assert_eq!(batch.num_columns(), 3);
+ assert_eq!(batch.num_rows(), 12);
+
+ let expected = vec![
+ "+---+---+----------+",
+ "| a | b | COUNT(1) |",
+ "+---+---+----------+",
+ "| | 1 | 2 |",
+ "| | 2 | 2 |",
+ "| | 3 | 2 |",
+ "| | 4 | 2 |",
+ "| 2 | | 2 |",
+ "| 2 | 1 | 2 |",
+ "| 3 | | 3 |",
+ "| 3 | 2 | 2 |",
+ "| 3 | 3 | 1 |",
+ "| 4 | | 3 |",
+ "| 4 | 3 | 1 |",
+ "| 4 | 4 | 2 |",
+ "+---+---+----------+",
+ ];
+
+ assert_batches_sorted_eq!(&expected, &result);
+
+ let metrics = merged_aggregate.metrics().unwrap();
+ let output_rows = metrics.output_rows().unwrap();
+ assert_eq!(12, output_rows);
+
+ Ok(())
+ }
+
/// build the aggregates on the data from some_data() and check the results
async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
let input_schema = input.schema();
- let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
- vec![(col("a", &input_schema)?, "a".to_string())];
+ let grouping_set = PhysicalGroupBy {
+ expr: vec![(col("a", &input_schema)?, "a".to_string())],
+ null_expr: vec![],
+ groups: vec![vec![false]],
+ };
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &input_schema)?,
@@ -546,7 +833,7 @@ mod tests {
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
- groups.clone(),
+ grouping_set.clone(),
aggregates.clone(),
input,
input_schema.clone(),
@@ -568,17 +855,17 @@ mod tests {
let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
- let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
- .map(|i| col(&groups[i].1, &input_schema))
+ let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
+ .expr
+ .iter()
+ .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
.collect::<Result<_>>()?;
+ let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+
let merged_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
- final_group
- .iter()
- .enumerate()
- .map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
- .collect(),
+ final_grouping_set,
aggregates,
merge,
input_schema,
@@ -719,6 +1006,14 @@ mod tests {
check_aggregates(input).await
}
+ #[tokio::test]
+ async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
+ let input: Arc<dyn ExecutionPlan> =
+ Arc::new(TestYieldingExec { yield_first: false });
+
+ check_grouping_sets(input).await
+ }
+
#[tokio::test]
async fn aggregate_source_with_yielding() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
@@ -727,6 +1022,14 @@ mod tests {
check_aggregates(input).await
}
+ #[tokio::test]
+ async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
+ let input: Arc<dyn ExecutionPlan> =
+ Arc::new(TestYieldingExec { yield_first: true });
+
+ check_grouping_sets(input).await
+ }
+
#[tokio::test]
async fn test_drop_cancel_without_groups() -> Result<()> {
let session_ctx = SessionContext::new();
@@ -734,7 +1037,7 @@ mod tests {
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
- let groups = vec![];
+ let groups = PhysicalGroupBy::default();
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("a", &schema)?,
@@ -771,8 +1074,8 @@ mod tests {
Field::new("b", DataType::Float32, true),
]));
- let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
- vec![(col("a", &schema)?, "a".to_string())];
+ let groups =
+ PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
col("b", &schema)?,
@@ -784,7 +1087,7 @@ mod tests {
let refs = blocking_exec.refs();
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
- groups.clone(),
+ groups,
aggregates.clone(),
blocking_exec,
schema,
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index e364048e7..5353bc745 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -29,7 +29,8 @@ use futures::{
use crate::error::Result;
use crate::physical_plan::aggregates::{
- evaluate, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode,
+ evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode,
+ PhysicalGroupBy,
};
use crate::physical_plan::hash_utils::create_row_hashes;
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
@@ -75,7 +76,7 @@ pub(crate) struct GroupedHashAggregateStreamV2 {
aggr_state: AggregationState,
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
- group_expr: Vec<Arc<dyn PhysicalExpr>>,
+ group_by: PhysicalGroupBy,
accumulators: Vec<AccumulatorItemV2>,
group_schema: SchemaRef,
@@ -100,7 +101,7 @@ impl GroupedHashAggregateStreamV2 {
pub fn new(
mode: AggregateMode,
schema: SchemaRef,
- group_expr: Vec<Arc<dyn PhysicalExpr>>,
+ group_by: PhysicalGroupBy,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
@@ -111,11 +112,11 @@ impl GroupedHashAggregateStreamV2 {
// Assume create_schema() always put group columns in front of aggr columns, we set
// col_idx_base to group expression count.
let aggregate_expressions =
- aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?;
+ aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?;
let accumulators = aggregates::create_accumulators_v2(&aggr_expr)?;
- let group_schema = group_schema(&schema, group_expr.len());
+ let group_schema = group_schema(&schema, group_by.expr.len());
let aggr_schema = aggr_state_schema(&aggr_expr)?;
let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned));
@@ -125,7 +126,7 @@ impl GroupedHashAggregateStreamV2 {
schema,
mode,
input,
- group_expr,
+ group_by,
accumulators,
group_schema,
aggr_schema,
@@ -160,7 +161,7 @@ impl Stream for GroupedHashAggregateStreamV2 {
let result = group_aggregate_batch(
&this.mode,
&this.random_state,
- &this.group_expr,
+ &this.group_by,
&mut this.accumulators,
&this.group_schema,
this.aggr_layout.clone(),
@@ -212,7 +213,7 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 {
fn group_aggregate_batch(
mode: &AggregateMode,
random_state: &RandomState,
- group_expr: &[Arc<dyn PhysicalExpr>],
+ grouping_set: &PhysicalGroupBy,
accumulators: &mut [AccumulatorItemV2],
group_schema: &Schema,
state_layout: Arc<RowLayout>,
@@ -221,142 +222,145 @@ fn group_aggregate_batch(
aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
) -> Result<()> {
// evaluate the grouping expressions
- let group_values = evaluate(group_expr, &batch)?;
- let group_rows: Vec<Vec<u8>> = create_group_rows(group_values, group_schema);
-
- // evaluate the aggregation expressions.
- // We could evaluate them after the `take`, but since we need to evaluate all
- // of them anyways, it is more performant to do it while they are together.
- let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
-
- // 1.1 construct the key from the group values
- // 1.2 construct the mapping key if it does not exist
- // 1.3 add the row' index to `indices`
-
- // track which entries in `aggr_state` have rows in this batch to aggregate
- let mut groups_with_rows = vec![];
-
- // 1.1 Calculate the group keys for the group values
- let mut batch_hashes = vec![0; batch.num_rows()];
- create_row_hashes(&group_rows, random_state, &mut batch_hashes)?;
-
- for (row, hash) in batch_hashes.into_iter().enumerate() {
- let AggregationState { map, group_states } = aggr_state;
-
- let entry = map.get_mut(hash, |(_hash, group_idx)| {
- // verify that a group that we are inserting with hash is
- // actually the same key value as the group in
- // existing_idx (aka group_values @ row)
- let group_state = &group_states[*group_idx];
- group_rows[row] == group_state.group_by_values
- });
-
- match entry {
- // Existing entry for this group value
- Some((_hash, group_idx)) => {
- let group_state = &mut group_states[*group_idx];
- // 1.3
- if group_state.indices.is_empty() {
- groups_with_rows.push(*group_idx);
- };
- group_state.indices.push(row as u32); // remember this row
- }
- // 1.2 Need to create new entry
- None => {
- // Add new entry to group_states and save newly created index
- let group_state = RowGroupState {
- group_by_values: group_rows[row].clone(),
- aggregation_buffer: vec![0; state_layout.fixed_part_width()],
- indices: vec![row as u32], // 1.3
- };
- let group_idx = group_states.len();
- group_states.push(group_state);
- groups_with_rows.push(group_idx);
-
- // for hasher function, use precomputed hash value
- map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
- }
- };
- }
-
- // Collect all indices + offsets based on keys in this vec
- let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
- let mut offsets = vec![0];
- let mut offset_so_far = 0;
- for group_idx in groups_with_rows.iter() {
- let indices = &aggr_state.group_states[*group_idx].indices;
- batch_indices.append_slice(indices)?;
- offset_so_far += indices.len();
- offsets.push(offset_so_far);
- }
- let batch_indices = batch_indices.finish();
+ let grouping_by_values = evaluate_group_by(grouping_set, &batch)?;
+
+ for group_values in grouping_by_values {
+ let group_rows: Vec<Vec<u8>> = create_group_rows(group_values, group_schema);
+
+ // evaluate the aggregation expressions.
+ // We could evaluate them after the `take`, but since we need to evaluate all
+ // of them anyways, it is more performant to do it while they are together.
+ let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
+
+ // 1.1 construct the key from the group values
+ // 1.2 construct the mapping key if it does not exist
+ // 1.3 add the row's index to `indices`
+
+ // track which entries in `aggr_state` have rows in this batch to aggregate
+ let mut groups_with_rows = vec![];
+
+ // 1.1 Calculate the group keys for the group values
+ let mut batch_hashes = vec![0; batch.num_rows()];
+ create_row_hashes(&group_rows, random_state, &mut batch_hashes)?;
+
+ for (row, hash) in batch_hashes.into_iter().enumerate() {
+ let AggregationState { map, group_states } = aggr_state;
+
+ let entry = map.get_mut(hash, |(_hash, group_idx)| {
+ // verify that a group that we are inserting with hash is
+ // actually the same key value as the group in
+ // existing_idx (aka group_values @ row)
+ let group_state = &group_states[*group_idx];
+ group_rows[row] == group_state.group_by_values
+ });
+
+ match entry {
+ // Existing entry for this group value
+ Some((_hash, group_idx)) => {
+ let group_state = &mut group_states[*group_idx];
+ // 1.3
+ if group_state.indices.is_empty() {
+ groups_with_rows.push(*group_idx);
+ };
+ group_state.indices.push(row as u32); // remember this row
+ }
+ // 1.2 Need to create new entry
+ None => {
+ // Add new entry to group_states and save newly created index
+ let group_state = RowGroupState {
+ group_by_values: group_rows[row].clone(),
+ aggregation_buffer: vec![0; state_layout.fixed_part_width()],
+ indices: vec![row as u32], // 1.3
+ };
+ let group_idx = group_states.len();
+ group_states.push(group_state);
+ groups_with_rows.push(group_idx);
+
+ // for hasher function, use precomputed hash value
+ map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
+ }
+ };
+ }
- // `Take` all values based on indices into Arrays
- let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
- .iter()
- .map(|array| {
- array
- .iter()
- .map(|array| {
- compute::take(
- array.as_ref(),
- &batch_indices,
- None, // None: no index check
- )
- .unwrap()
- })
- .collect()
- // 2.3
- })
- .collect();
-
- // 2.1 for each key in this batch
- // 2.2 for each aggregation
- // 2.3 `slice` from each of its arrays the keys' values
- // 2.4 update / merge the accumulator with the values
- // 2.5 clear indices
- groups_with_rows
- .iter()
- .zip(offsets.windows(2))
- .try_for_each(|(group_idx, offsets)| {
- let group_state = &mut aggr_state.group_states[*group_idx];
- // 2.2
- accumulators
- .iter_mut()
- .zip(values.iter())
- .map(|(accumulator, aggr_array)| {
- (
- accumulator,
- aggr_array
- .iter()
- .map(|array| {
- // 2.3
- array.slice(offsets[0], offsets[1] - offsets[0])
- })
- .collect::<Vec<ArrayRef>>(),
- )
- })
- .try_for_each(|(accumulator, values)| {
- let mut state_accessor =
- RowAccessor::new_from_layout(state_layout.clone());
- state_accessor
- .point_to(0, group_state.aggregation_buffer.as_mut_slice());
- match mode {
- AggregateMode::Partial => {
- accumulator.update_batch(&values, &mut state_accessor)
- }
- AggregateMode::FinalPartitioned | AggregateMode::Final => {
- // note: the aggregation here is over states, not values, thus the merge
- accumulator.merge_batch(&values, &mut state_accessor)
+ // Collect all indices + offsets based on keys in this vec
+ let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
+ let mut offsets = vec![0];
+ let mut offset_so_far = 0;
+ for group_idx in groups_with_rows.iter() {
+ let indices = &aggr_state.group_states[*group_idx].indices;
+ batch_indices.append_slice(indices)?;
+ offset_so_far += indices.len();
+ offsets.push(offset_so_far);
+ }
+ let batch_indices = batch_indices.finish();
+
+ // `Take` all values based on indices into Arrays
+ let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
+ .iter()
+ .map(|array| {
+ array
+ .iter()
+ .map(|array| {
+ compute::take(
+ array.as_ref(),
+ &batch_indices,
+ None, // None: no index check
+ )
+ .unwrap()
+ })
+ .collect()
+ // 2.3
+ })
+ .collect();
+
+ // 2.1 for each key in this batch
+ // 2.2 for each aggregation
+ // 2.3 `slice` from each of its arrays the keys' values
+ // 2.4 update / merge the accumulator with the values
+ // 2.5 clear indices
+ groups_with_rows
+ .iter()
+ .zip(offsets.windows(2))
+ .try_for_each(|(group_idx, offsets)| {
+ let group_state = &mut aggr_state.group_states[*group_idx];
+ // 2.2
+ accumulators
+ .iter_mut()
+ .zip(values.iter())
+ .map(|(accumulator, aggr_array)| {
+ (
+ accumulator,
+ aggr_array
+ .iter()
+ .map(|array| {
+ // 2.3
+ array.slice(offsets[0], offsets[1] - offsets[0])
+ })
+ .collect::<Vec<ArrayRef>>(),
+ )
+ })
+ .try_for_each(|(accumulator, values)| {
+ let mut state_accessor =
+ RowAccessor::new_from_layout(state_layout.clone());
+ state_accessor
+ .point_to(0, group_state.aggregation_buffer.as_mut_slice());
+ match mode {
+ AggregateMode::Partial => {
+ accumulator.update_batch(&values, &mut state_accessor)
+ }
+ AggregateMode::FinalPartitioned | AggregateMode::Final => {
+ // note: the aggregation here is over states, not values, thus the merge
+ accumulator.merge_batch(&values, &mut state_accessor)
+ }
}
- }
- })
- // 2.5
- .and({
- group_state.indices.clear();
- Ok(())
- })
- })?;
+ })
+ // 2.5
+ .and({
+ group_state.indices.clear();
+ Ok(())
+ })
+ })?;
+ }
Ok(())
}
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 60cc3b8de..14cdee301 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -37,7 +37,7 @@ use crate::logical_plan::{
use crate::logical_plan::{Limit, Values};
use crate::physical_expr::create_physical_expr;
use crate::physical_optimizer::optimizer::PhysicalOptimizerRule;
-use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use crate::physical_plan::cross_join::CrossJoinExec;
use crate::physical_plan::explain::ExplainExec;
use crate::physical_plan::expressions::{Column, PhysicalSortExpr};
@@ -58,10 +58,13 @@ use arrow::compute::SortOptions;
use arrow::datatypes::DataType;
use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
+use datafusion_common::ScalarValue;
use datafusion_expr::{expr::GroupingSet, utils::expr_to_columns};
+use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
+use itertools::Itertools;
use log::{debug, trace};
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
@@ -535,20 +538,12 @@ impl DefaultPhysicalPlanner {
let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();
- let groups = group_expr
- .iter()
- .map(|e| {
- tuple_err((
- self.create_physical_expr(
- e,
- logical_input_schema,
- &physical_input_schema,
- session_state,
- ),
- physical_name(e),
- ))
- })
- .collect::<Result<Vec<_>>>()?;
+ let groups = self.create_grouping_physical_expr(
+ group_expr,
+ logical_input_schema,
+ &physical_input_schema,
+ session_state)?;
+
let aggregates = aggr_expr
.iter()
.map(|e| {
@@ -574,6 +569,7 @@ impl DefaultPhysicalPlanner {
// TODO: dictionary type not yet supported in Hash Repartition
let contains_dict = groups
+ .expr()
.iter()
.flat_map(|x| x.0.data_type(physical_input_schema.as_ref()))
.any(|x| matches!(x, DataType::Dictionary(_, _)));
@@ -603,13 +599,17 @@ impl DefaultPhysicalPlanner {
(initial_aggr, AggregateMode::Final)
};
- Ok(Arc::new(AggregateExec::try_new(
- next_partition_mode,
+ let final_grouping_set = PhysicalGroupBy::new_single(
final_group
.iter()
.enumerate()
- .map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
- .collect(),
+ .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
+ .collect()
+ );
+
+ Ok(Arc::new(AggregateExec::try_new(
+ next_partition_mode,
+ final_grouping_set,
aggregates,
initial_aggr,
physical_input_schema.clone(),
@@ -1001,6 +1001,261 @@ impl DefaultPhysicalPlanner {
exec_plan
}.boxed()
}
+
+    /// Translate the logical `GROUP BY` expressions of an aggregate into a
+    /// [`PhysicalGroupBy`].
+    ///
+    /// When the list is exactly one `GROUPING SETS`, `CUBE` or `ROLLUP`
+    /// expression it is expanded by the corresponding helper function. Every
+    /// other list (including an empty one) is planned expression by
+    /// expression and wrapped with [`PhysicalGroupBy::new_single`].
+    fn create_grouping_physical_expr(
+        &self,
+        group_expr: &[Expr],
+        input_dfschema: &DFSchema,
+        input_schema: &Schema,
+        session_state: &SessionState,
+    ) -> Result<PhysicalGroupBy> {
+        match group_expr {
+            [Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))] => {
+                merge_grouping_set_physical_expr(
+                    grouping_sets,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                )
+            }
+            [Expr::GroupingSet(GroupingSet::Cube(exprs))] => create_cube_physical_expr(
+                exprs,
+                input_dfschema,
+                input_schema,
+                session_state,
+            ),
+            [Expr::GroupingSet(GroupingSet::Rollup(exprs))] => {
+                create_rollup_physical_expr(
+                    exprs,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                )
+            }
+            // Ordinary expressions: pair each physical expression with its
+            // output name. A single ordinary expression takes this path too.
+            plain_exprs => {
+                let pairs = plain_exprs
+                    .iter()
+                    .map(|e| {
+                        tuple_err((
+                            self.create_physical_expr(
+                                e,
+                                input_dfschema,
+                                input_schema,
+                                session_state,
+                            ),
+                            physical_name(e),
+                        ))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                Ok(PhysicalGroupBy::new_single(pairs))
+            }
+        }
+    }
+}
+
+/// Expand and align a `GROUPING SETS` expression.
+/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS)
+///
+/// Takes a list of grouping sets and conforms every set to the same list of
+/// expressions, in the same order. The distinct expressions are collected in
+/// order of first appearance; each set is then described by one boolean per
+/// distinct expression, where `true` marks an expression that is absent from
+/// that set.
+///
+/// For example, `GROUPING SETS ((a,b,c),(a),(b),(b,c))` is aligned to
+/// `GROUPING SETS ((a,b,c),(a,NULL,NULL),(NULL,b,NULL),(NULL,b,c))`.
+fn merge_grouping_set_physical_expr(
+    grouping_sets: &[Vec<Expr>],
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<PhysicalGroupBy> {
+    let mut all_exprs: Vec<Expr> = vec![];
+    let mut grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+    let mut null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+
+    // Plan each distinct expression once, together with its typed NULL.
+    for expr in grouping_sets.iter().flatten() {
+        if all_exprs.contains(expr) {
+            continue;
+        }
+        all_exprs.push(expr.clone());
+
+        grouping_set_expr.push(get_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+
+        null_exprs.push(get_null_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+    }
+
+    // One mask per grouping set, aligned with `all_exprs`.
+    let merged_sets: Vec<Vec<bool>> = grouping_sets
+        .iter()
+        .map(|set| all_exprs.iter().map(|expr| !set.contains(expr)).collect())
+        .collect();
+
+    Ok(PhysicalGroupBy::new(
+        grouping_set_expr,
+        null_exprs,
+        merged_sets,
+    ))
+}
+
+/// Expand and align a `CUBE` expression. This is a special case of `GROUPING SETS`
+/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS)
+///
+/// Every expression is planned once, together with a typed NULL literal of
+/// the same name. The returned groups are boolean masks over `exprs`: first
+/// the all-`false` mask, then every combination of `true` positions, ordered
+/// by the number of `true` entries.
+fn create_cube_physical_expr(
+    exprs: &[Expr],
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<PhysicalGroupBy> {
+    let num_of_exprs = exprs.len();
+    // A cube over n expressions has 2^n grouping sets (one per subset of the
+    // expressions), not n^2. The value is only a capacity hint, so fall back
+    // to no preallocation if the shift does not fit in `usize`.
+    let num_groups = 1usize.checked_shl(num_of_exprs as u32).unwrap_or(0);
+
+    let mut null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        Vec::with_capacity(num_of_exprs);
+    let mut all_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        Vec::with_capacity(num_of_exprs);
+
+    for expr in exprs {
+        null_exprs.push(get_null_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+
+        all_exprs.push(get_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?)
+    }
+
+    let mut groups: Vec<Vec<bool>> = Vec::with_capacity(num_groups);
+
+    // The mask with no `true` entries comes first.
+    groups.push(vec![false; num_of_exprs]);
+
+    // Then every non-empty choice of positions to mark `true`.
+    for null_count in 1..=num_of_exprs {
+        for null_idx in (0..num_of_exprs).combinations(null_count) {
+            let mut next_group: Vec<bool> = vec![false; num_of_exprs];
+            null_idx.into_iter().for_each(|i| next_group[i] = true);
+            groups.push(next_group);
+        }
+    }
+
+    Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups))
+}
+
+/// Expand and align a `ROLLUP` expression. This is a special case of `GROUPING SETS`
+/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS)
+///
+/// Produces `exprs.len() + 1` boolean masks over `exprs`. The mask at
+/// position `kept` is `false` for the first `kept` expressions and `true` for
+/// the rest, so the masks run from all-`true` to all-`false`.
+fn create_rollup_physical_expr(
+    exprs: &[Expr],
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<PhysicalGroupBy> {
+    let num_of_exprs = exprs.len();
+
+    let mut null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        Vec::with_capacity(num_of_exprs);
+    let mut all_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        Vec::with_capacity(num_of_exprs);
+
+    for expr in exprs {
+        null_exprs.push(get_null_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+
+        all_exprs.push(get_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?)
+    }
+
+    // `kept` counts the leading expressions that stay `false` in each mask.
+    let groups: Vec<Vec<bool>> = (0..=num_of_exprs)
+        .map(|kept| (0..num_of_exprs).map(|index| index >= kept).collect())
+        .collect();
+
+    Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups))
+}
+
+/// For a given logical expr, get a properly typed NULL `ScalarValue`
+/// physical expression, paired with the expression's physical name.
+///
+/// The expression is planned only to learn its output data type against
+/// `input_schema`; the returned physical expression is a [`Literal`] holding
+/// the `ScalarValue` converted from that data type.
+///
+/// # Errors
+///
+/// Returns an error if the expression cannot be planned or named, if its data
+/// type cannot be resolved, or if the data type has no `ScalarValue`
+/// conversion.
+fn get_null_physical_expr_pair(
+    expr: &Expr,
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<(Arc<dyn PhysicalExpr>, String)> {
+    let physical_expr = create_physical_expr(
+        expr,
+        input_dfschema,
+        input_schema,
+        &session_state.execution_props,
+    )?;
+    // Borrow the expression directly; cloning it just to take a reference is
+    // wasted work.
+    let name = physical_name(expr)?;
+
+    let data_type = physical_expr.data_type(input_schema)?;
+    let null_value: ScalarValue = (&data_type).try_into()?;
+
+    Ok((Arc::new(Literal::new(null_value)), name))
+}
+
+/// For a given logical expr, get its physical expression paired with its
+/// physical name.
+fn get_physical_expr_pair(
+    expr: &Expr,
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<(Arc<dyn PhysicalExpr>, String)> {
+    // Plan first, then name, so a planning error is reported first.
+    let planned = create_physical_expr(
+        expr,
+        input_dfschema,
+        input_schema,
+        &session_state.execution_props,
+    )?;
+    physical_name(expr).map(|name| (planned, name))
}
/// Create a window expression with a name from a logical expression
@@ -1303,6 +1558,7 @@ mod tests {
};
use arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion_common::{DFField, DFSchema, DFSchemaRef};
+ use datafusion_expr::expr::GroupingSet;
use datafusion_expr::sum;
use datafusion_expr::{col, lit};
use fmt::Debug;
@@ -1346,6 +1602,60 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn test_create_cube_expr() -> Result<()> {
+ let logical_plan = test_csv_scan().await?.build()?;
+
+ let plan = plan(&logical_plan).await?;
+
+ let exprs = vec![col("c1"), col("c2"), col("c3")];
+
+ let physical_input_schema = plan.schema();
+ let physical_input_schema = physical_input_schema.as_ref();
+ let logical_input_schema = logical_plan.schema();
+ let session_state = make_session_state();
+
+ let cube = create_cube_physical_expr(
+ &exprs,
+ logical_input_schema,
+ physical_input_schema,
+ &session_state,
+ );
+
+ let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
+
+ assert_eq!(format!("{:?}", cube), expected);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_create_rollup_expr() -> Result<()> {
+ let logical_plan = test_csv_scan().await?.build()?;
+
+ let plan = plan(&logical_plan).await?;
+
+ let exprs = vec![col("c1"), col("c2"), col("c3")];
+
+ let physical_input_schema = plan.schema();
+ let physical_input_schema = physical_input_schema.as_ref();
+ let logical_input_schema = logical_plan.schema();
+ let session_state = make_session_state();
+
+ let rollup = create_rollup_physical_expr(
+ &exprs,
+ logical_input_schema,
+ physical_input_schema,
+ &session_state,
+ );
+
+ let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
+
+ assert_eq!(format!("{:?}", rollup), expected);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_create_not() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
@@ -1620,6 +1930,34 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn hash_agg_grouping_set_input_schema() -> Result<()> {
+ let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
+ vec![col("c1")],
+ vec![col("c2")],
+ vec![col("c1"), col("c2")],
+ ]));
+ let logical_plan = test_csv_scan_with_name("aggregate_test_100")
+ .await?
+ .aggregate(vec![grouping_set_expr], vec![sum(col("c3"))])?
+ .build()?;
+
+ let execution_plan = plan(&logical_plan).await?;
+ let final_hash_agg = execution_plan
+ .as_any()
+ .downcast_ref::<AggregateExec>()
+ .expect("hash aggregate");
+ assert_eq!(
+ "SUM(aggregate_test_100.c3)",
+ final_hash_agg.schema().field(2).name()
+ );
+ // we need access to the input to the partial aggregate so that other projects can
+ // implement serde
+ assert_eq!("c3", final_hash_agg.input_schema().field(2).name());
+
+ Ok(())
+ }
+
#[tokio::test]
async fn hash_agg_group_by_partitioned() -> Result<()> {
let logical_plan = test_csv_scan()
@@ -1637,6 +1975,28 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn hash_agg_grouping_set_by_partitioned() -> Result<()> {
+ let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
+ vec![col("c1")],
+ vec![col("c2")],
+ vec![col("c1"), col("c2")],
+ ]));
+ let logical_plan = test_csv_scan()
+ .await?
+ .aggregate(vec![grouping_set_expr], vec![sum(col("c3"))])?
+ .build()?;
+
+ let execution_plan = plan(&logical_plan).await?;
+ let formatted = format!("{:?}", execution_plan);
+
+ // Make sure the plan contains a FinalPartitioned, which means it will not use the Final
+ // mode in Aggregate (which is slower)
+ assert!(formatted.contains("FinalPartitioned"));
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_explain() {
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 38f54a2a7..b25e83cb7 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -24,11 +24,14 @@ use datafusion::from_slice::FromSlice;
use std::sync::Arc;
use datafusion::assert_batches_eq;
+use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::logical_plan::{col, Expr};
+use datafusion::prelude::CsvReadOptions;
use datafusion::{datasource::MemTable, prelude::JoinType};
-use datafusion_expr::lit;
+use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::{avg, count, lit, sum};
#[tokio::test]
async fn join() -> Result<()> {
@@ -207,3 +210,217 @@ async fn select_with_alias_overwrite() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn test_grouping_sets() -> Result<()> {
+ let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
+ vec![col("a")],
+ vec![col("b")],
+ vec![col("a"), col("b")],
+ ]));
+
+ let df = create_test_table()?
+ .aggregate(vec![grouping_set_expr], vec![count(col("a"))])?
+ .sort(vec![
+ Expr::Sort {
+ expr: Box::new(col("a")),
+ asc: false,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("b")),
+ asc: false,
+ nulls_first: true,
+ },
+ ])?;
+
+ let results = df.collect().await?;
+
+ let expected = vec![
+ "+-----------+-----+---------------+",
+ "| a | b | COUNT(test.a) |",
+ "+-----------+-----+---------------+",
+ "| | 100 | 1 |",
+ "| | 10 | 2 |",
+ "| | 1 | 1 |",
+ "| abcDEF | | 1 |",
+ "| abcDEF | 1 | 1 |",
+ "| abc123 | | 1 |",
+ "| abc123 | 10 | 1 |",
+ "| CBAdef | | 1 |",
+ "| CBAdef | 10 | 1 |",
+ "| 123AbcDef | | 1 |",
+ "| 123AbcDef | 100 | 1 |",
+ "+-----------+-----+---------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_grouping_sets_count() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
+ vec![col("c1")],
+ vec![col("c2")],
+ ]));
+
+ let df = aggregates_table(&ctx)
+ .await?
+ .aggregate(vec![grouping_set_expr], vec![count(lit(1))])?
+ .sort(vec![
+ Expr::Sort {
+ expr: Box::new(col("c1")),
+ asc: false,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("c2")),
+ asc: false,
+ nulls_first: true,
+ },
+ ])?;
+
+ let results = df.collect().await?;
+
+ let expected = vec![
+ "+----+----+-----------------+",
+ "| c1 | c2 | COUNT(Int32(1)) |",
+ "+----+----+-----------------+",
+ "| | 5 | 14 |",
+ "| | 4 | 23 |",
+ "| | 3 | 19 |",
+ "| | 2 | 22 |",
+ "| | 1 | 22 |",
+ "| e | | 21 |",
+ "| d | | 18 |",
+ "| c | | 21 |",
+ "| b | | 19 |",
+ "| a | | 21 |",
+ "+----+----+-----------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![
+ vec![col("c1")],
+ vec![col("c2")],
+ vec![col("c1"), col("c2")],
+ ]));
+
+ let df = aggregates_table(&ctx)
+ .await?
+ .aggregate(
+ vec![grouping_set_expr],
+ vec![
+ sum(col("c3")).alias("sum_c3"),
+ avg(col("c3")).alias("avg_c3"),
+ ],
+ )?
+ .sort(vec![
+ Expr::Sort {
+ expr: Box::new(col("c1")),
+ asc: false,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("c2")),
+ asc: false,
+ nulls_first: true,
+ },
+ ])?;
+
+ let results = df.collect().await?;
+
+ let expected = vec![
+ "+----+----+--------+---------------------+",
+ "| c1 | c2 | sum_c3 | avg_c3 |",
+ "+----+----+--------+---------------------+",
+ "| | 5 | -194 | -13.857142857142858 |",
+ "| | 4 | 29 | 1.2608695652173914 |",
+ "| | 3 | 395 | 20.789473684210527 |",
+ "| | 2 | 184 | 8.363636363636363 |",
+ "| | 1 | 367 | 16.681818181818183 |",
+ "| e | | 847 | 40.333333333333336 |",
+ "| e | 5 | -22 | -11 |",
+ "| e | 4 | 261 | 37.285714285714285 |",
+ "| e | 3 | 192 | 48 |",
+ "| e | 2 | 189 | 37.8 |",
+ "| e | 1 | 227 | 75.66666666666667 |",
+ "| d | | 458 | 25.444444444444443 |",
+ "| d | 5 | -99 | -49.5 |",
+ "| d | 4 | 162 | 54 |",
+ "| d | 3 | 124 | 41.333333333333336 |",
+ "| d | 2 | 328 | 109.33333333333333 |",
+ "| d | 1 | -57 | -8.142857142857142 |",
+ "| c | | -28 | -1.3333333333333333 |",
+ "| c | 5 | 24 | 12 |",
+ "| c | 4 | -43 | -10.75 |",
+ "| c | 3 | 190 | 47.5 |",
+ "| c | 2 | -389 | -55.57142857142857 |",
+ "| c | 1 | 190 | 47.5 |",
+ "| b | | -111 | -5.842105263157895 |",
+ "| b | 5 | -1 | -0.2 |",
+ "| b | 4 | -223 | -44.6 |",
+ "| b | 3 | -84 | -42 |",
+ "| b | 2 | 102 | 25.5 |",
+ "| b | 1 | 95 | 31.666666666666668 |",
+ "| a | | -385 | -18.333333333333332 |",
+ "| a | 5 | -96 | -32 |",
+ "| a | 4 | -128 | -32 |",
+ "| a | 3 | -27 | -4.5 |",
+ "| a | 2 | -46 | -15.333333333333334 |",
+ "| a | 1 | -88 | -17.6 |",
+ "+----+----+--------+---------------------+",
+ ];
+ assert_batches_eq!(expected, &results);
+
+ Ok(())
+}
+
+fn create_test_table() -> Result<Arc<DataFrame>> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Utf8, false),
+ Field::new("b", DataType::Int32, false),
+ ]));
+
+ // define data.
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(StringArray::from_slice(&[
+ "abcDEF",
+ "abc123",
+ "CBAdef",
+ "123AbcDef",
+ ])),
+ Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])),
+ ],
+ )?;
+
+ let ctx = SessionContext::new();
+
+ let table = MemTable::try_new(schema, vec![vec![batch]])?;
+
+ ctx.register_table("test", Arc::new(table))?;
+
+ ctx.table("test")
+}
+
+async fn aggregates_table(ctx: &SessionContext) -> Result<Arc<DataFrame>> {
+ let testdata = datafusion::test_util::arrow_test_data();
+
+ ctx.read_csv(
+ format!("{}/csv/aggregate_test_100.csv", testdata),
+ CsvReadOptions::default(),
+ )
+ .await
+}
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index 08ccbe453..61b1a1afa 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -476,6 +476,205 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
Ok(())
}
+// `GROUP BY CUBE (c1, c2)` with `AVG(c3)` over `aggregate_test_100`.
+// The expected table holds one row per (c1, c2) pair, one per c1 (blank c2),
+// one per c2 (blank c1) and a single row with both grouping columns blank.
+#[tokio::test]
+async fn csv_query_cube_avg() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+
+ let sql = "SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2";
+ let actual = execute_to_batches(&ctx, sql).await;
+ // Within the ordered output the blank (rolled-up) value follows the concrete ones.
+ let expected = vec![
+ "+----+----+----------------------------+",
+ "| c1 | c2 | AVG(aggregate_test_100.c3) |",
+ "+----+----+----------------------------+",
+ "| a | 1 | -17.6 |",
+ "| a | 2 | -15.333333333333334 |",
+ "| a | 3 | -4.5 |",
+ "| a | 4 | -32 |",
+ "| a | 5 | -32 |",
+ "| a | | -18.333333333333332 |",
+ "| b | 1 | 31.666666666666668 |",
+ "| b | 2 | 25.5 |",
+ "| b | 3 | -42 |",
+ "| b | 4 | -44.6 |",
+ "| b | 5 | -0.2 |",
+ "| b | | -5.842105263157895 |",
+ "| c | 1 | 47.5 |",
+ "| c | 2 | -55.57142857142857 |",
+ "| c | 3 | 47.5 |",
+ "| c | 4 | -10.75 |",
+ "| c | 5 | 12 |",
+ "| c | | -1.3333333333333333 |",
+ "| d | 1 | -8.142857142857142 |",
+ "| d | 2 | 109.33333333333333 |",
+ "| d | 3 | 41.333333333333336 |",
+ "| d | 4 | 54 |",
+ "| d | 5 | -49.5 |",
+ "| d | | 25.444444444444443 |",
+ "| e | 1 | 75.66666666666667 |",
+ "| e | 2 | 37.8 |",
+ "| e | 3 | 48 |",
+ "| e | 4 | 37.285714285714285 |",
+ "| e | 5 | -11 |",
+ "| e | | 40.333333333333336 |",
+ "| | 1 | 16.681818181818183 |",
+ "| | 2 | 8.363636363636363 |",
+ "| | 3 | 20.789473684210527 |",
+ "| | 4 | 1.2608695652173914 |",
+ "| | 5 | -13.857142857142858 |",
+ "| | | 7.81 |",
+ "+----+----+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+// `GROUP BY ROLLUP (c1, c2, c3)` with `AVG(c4)` over `aggregate_test_100`.
+// The expected table holds the full (c1, c2, c3) groups, a (c1, c2) subtotal row
+// after each c2 run (blank c3), a c1 subtotal row after each c1 run (blank c2 and c3)
+// and one final row with all three grouping columns blank.
+#[tokio::test]
+async fn csv_query_rollup_avg() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+
+ let sql = "SELECT c1, c2, c3, AVG(c4) FROM aggregate_test_100 GROUP BY ROLLUP (c1, c2, c3) ORDER BY c1, c2, c3";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+------+----------------------------+",
+ "| c1 | c2 | c3 | AVG(aggregate_test_100.c4) |",
+ "+----+----+------+----------------------------+",
+ "| a | 1 | -85 | -15154 |",
+ "| a | 1 | -56 | 8692 |",
+ "| a | 1 | -25 | 15295 |",
+ "| a | 1 | -5 | 12636 |",
+ "| a | 1 | 83 | -14704 |",
+ "| a | 1 | | 1353 |",
+ "| a | 2 | -48 | -18025 |",
+ "| a | 2 | -43 | 13080 |",
+ "| a | 2 | 45 | 15673 |",
+ "| a | 2 | | 3576 |",
+ "| a | 3 | -72 | -11122 |",
+ "| a | 3 | -12 | -9168 |",
+ "| a | 3 | 13 | 22338.5 |",
+ "| a | 3 | 14 | 28162 |",
+ "| a | 3 | 17 | -22796 |",
+ "| a | 3 | | 4958.833333333333 |",
+ "| a | 4 | -101 | 11640 |",
+ "| a | 4 | -54 | -2376 |",
+ "| a | 4 | -38 | 20744 |",
+ "| a | 4 | 65 | -28462 |",
+ "| a | 4 | | 386.5 |",
+ "| a | 5 | -101 | -12484 |",
+ "| a | 5 | -31 | -12907 |",
+ "| a | 5 | 36 | -16974 |",
+ "| a | 5 | | -14121.666666666666 |",
+ "| a | | | 306.04761904761904 |",
+ "| b | 1 | 12 | 7652 |",
+ "| b | 1 | 29 | -18218 |",
+ "| b | 1 | 54 | -18410 |",
+ "| b | 1 | | -9658.666666666666 |",
+ "| b | 2 | -60 | -21739 |",
+ "| b | 2 | 31 | 23127 |",
+ "| b | 2 | 63 | 21456 |",
+ "| b | 2 | 68 | 15874 |",
+ "| b | 2 | | 9679.5 |",
+ "| b | 3 | -101 | -13217 |",
+ "| b | 3 | 17 | 14457 |",
+ "| b | 3 | | 620 |",
+ "| b | 4 | -117 | 19316 |",
+ "| b | 4 | -111 | -1967 |",
+ "| b | 4 | -59 | 25286 |",
+ "| b | 4 | 17 | -28070 |",
+ "| b | 4 | 47 | 20690 |",
+ "| b | 4 | | 7051 |",
+ "| b | 5 | -82 | 22080 |",
+ "| b | 5 | -44 | 15788 |",
+ "| b | 5 | -5 | 24896 |",
+ "| b | 5 | 62 | 16337 |",
+ "| b | 5 | 68 | 21576 |",
+ "| b | 5 | | 20135.4 |",
+ "| b | | | 7732.315789473684 |",
+ "| c | 1 | -24 | -24085 |",
+ "| c | 1 | 41 | -4667 |",
+ "| c | 1 | 70 | 27752 |",
+ "| c | 1 | 103 | -22186 |",
+ "| c | 1 | | -5796.5 |",
+ "| c | 2 | -117 | -30187 |",
+ "| c | 2 | -107 | -2904 |",
+ "| c | 2 | -106 | -1114 |",
+ "| c | 2 | -60 | -16312 |",
+ "| c | 2 | -29 | 25305 |",
+ "| c | 2 | 1 | 18109 |",
+ "| c | 2 | 29 | -3855 |",
+ "| c | 2 | | -1565.4285714285713 |",
+ "| c | 3 | -2 | -18655 |",
+ "| c | 3 | 22 | 13741 |",
+ "| c | 3 | 73 | -9565 |",
+ "| c | 3 | 97 | 29106 |",
+ "| c | 3 | | 3656.75 |",
+ "| c | 4 | -90 | -2935 |",
+ "| c | 4 | -79 | 5281 |",
+ "| c | 4 | 3 | -30508 |",
+ "| c | 4 | 123 | 16620 |",
+ "| c | 4 | | -2885.5 |",
+ "| c | 5 | -94 | -15880 |",
+ "| c | 5 | 118 | 19208 |",
+ "| c | 5 | | 1664 |",
+ "| c | | | -1320.5238095238096 |",
+ "| d | 1 | -99 | 5613 |",
+ "| d | 1 | -98 | 13630 |",
+ "| d | 1 | -72 | 25590 |",
+ "| d | 1 | -8 | 27138 |",
+ "| d | 1 | 38 | 18384 |",
+ "| d | 1 | 57 | 28781 |",
+ "| d | 1 | 125 | 31106 |",
+ "| d | 1 | | 21463.14285714286 |",
+ "| d | 2 | 93 | -12642 |",
+ "| d | 2 | 113 | 3917 |",
+ "| d | 2 | 122 | 10130 |",
+ "| d | 2 | | 468.3333333333333 |",
+ "| d | 3 | -76 | 8809 |",
+ "| d | 3 | 77 | 15091 |",
+ "| d | 3 | 123 | 29533 |",
+ "| d | 3 | | 17811 |",
+ "| d | 4 | 5 | -7688 |",
+ "| d | 4 | 55 | -1471 |",
+ "| d | 4 | 102 | -24558 |",
+ "| d | 4 | | -11239 |",
+ "| d | 5 | -59 | 2045 |",
+ "| d | 5 | -40 | 22614 |",
+ "| d | 5 | | 12329.5 |",
+ "| d | | | 10890.111111111111 |",
+ "| e | 1 | 36 | -21481 |",
+ "| e | 1 | 71 | -5479 |",
+ "| e | 1 | 120 | 10837 |",
+ "| e | 1 | | -5374.333333333333 |",
+ "| e | 2 | -61 | -2888 |",
+ "| e | 2 | 49 | 24495 |",
+ "| e | 2 | 52 | 5666 |",
+ "| e | 2 | 97 | 18167 |",
+ "| e | 2 | | 10221.2 |",
+ "| e | 3 | -95 | 13611 |",
+ "| e | 3 | 71 | 194 |",
+ "| e | 3 | 104 | -25136 |",
+ "| e | 3 | 112 | -6823 |",
+ "| e | 3 | | -4538.5 |",
+ "| e | 4 | -56 | -31500 |",
+ "| e | 4 | -53 | 13788 |",
+ "| e | 4 | 30 | -16110 |",
+ "| e | 4 | 73 | -22501 |",
+ "| e | 4 | 74 | -12612 |",
+ "| e | 4 | 96 | -30336 |",
+ "| e | 4 | 97 | -13181 |",
+ "| e | 4 | | -16064.57142857143 |",
+ "| e | 5 | -86 | 32514 |",
+ "| e | 5 | 64 | -26526 |",
+ "| e | 5 | | 2994 |",
+ "| e | | | -4268.333333333333 |",
+ "| | | | 2319.97 |",
+ "+----+----+------+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
let ctx = SessionContext::new();
@@ -583,6 +782,200 @@ async fn csv_query_sum_crossjoin() {
assert_batches_eq!(expected, &actual);
}
+// `GROUP BY CUBE (a.c1, b.c1)` with `SUM(a.c2)` over the cross join of
+// `aggregate_test_100` with itself. The expected table lists every (a.c1, b.c1)
+// pair, the per-a.c1 rows (blank b.c1), the per-b.c1 rows (blank a.c1) and a
+// single row with both grouping columns blank.
+#[tokio::test]
+async fn csv_query_cube_sum_crossjoin() {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+ let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+-----------+",
+ "| c1 | c1 | SUM(a.c2) |",
+ "+----+----+-----------+",
+ "| a | a | 1260 |",
+ "| a | b | 1140 |",
+ "| a | c | 1260 |",
+ "| a | d | 1080 |",
+ "| a | e | 1260 |",
+ "| a | | 6000 |",
+ "| b | a | 1302 |",
+ "| b | b | 1178 |",
+ "| b | c | 1302 |",
+ "| b | d | 1116 |",
+ "| b | e | 1302 |",
+ "| b | | 6200 |",
+ "| c | a | 1176 |",
+ "| c | b | 1064 |",
+ "| c | c | 1176 |",
+ "| c | d | 1008 |",
+ "| c | e | 1176 |",
+ "| c | | 5600 |",
+ "| d | a | 924 |",
+ "| d | b | 836 |",
+ "| d | c | 924 |",
+ "| d | d | 792 |",
+ "| d | e | 924 |",
+ "| d | | 4400 |",
+ "| e | a | 1323 |",
+ "| e | b | 1197 |",
+ "| e | c | 1323 |",
+ "| e | d | 1134 |",
+ "| e | e | 1323 |",
+ "| e | | 6300 |",
+ "| | a | 5985 |",
+ "| | b | 5415 |",
+ "| | c | 5985 |",
+ "| | d | 5130 |",
+ "| | e | 5985 |",
+ "| | | 28500 |",
+ "+----+----+-----------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+}
+
+// `COUNT(DISTINCT c3)` combined with `GROUP BY CUBE (c1,c2)`: a single distinct
+// aggregate evaluated per (c1, c2) pair, per c1, per c2 and over the whole table.
+#[tokio::test]
+async fn csv_query_cube_distinct_count() {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+ let sql = "SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+---------------------------------------+",
+ "| c1 | c2 | COUNT(DISTINCT aggregate_test_100.c3) |",
+ "+----+----+---------------------------------------+",
+ "| a | 1 | 5 |",
+ "| a | 2 | 3 |",
+ "| a | 3 | 5 |",
+ "| a | 4 | 4 |",
+ "| a | 5 | 3 |",
+ "| a | | 19 |",
+ "| b | 1 | 3 |",
+ "| b | 2 | 4 |",
+ "| b | 3 | 2 |",
+ "| b | 4 | 5 |",
+ "| b | 5 | 5 |",
+ "| b | | 17 |",
+ "| c | 1 | 4 |",
+ "| c | 2 | 7 |",
+ "| c | 3 | 4 |",
+ "| c | 4 | 4 |",
+ "| c | 5 | 2 |",
+ "| c | | 21 |",
+ "| d | 1 | 7 |",
+ "| d | 2 | 3 |",
+ "| d | 3 | 3 |",
+ "| d | 4 | 3 |",
+ "| d | 5 | 2 |",
+ "| d | | 18 |",
+ "| e | 1 | 3 |",
+ "| e | 2 | 4 |",
+ "| e | 3 | 4 |",
+ "| e | 4 | 7 |",
+ "| e | 5 | 2 |",
+ "| e | | 18 |",
+ "| | 1 | 22 |",
+ "| | 2 | 20 |",
+ "| | 3 | 17 |",
+ "| | 4 | 23 |",
+ "| | 5 | 14 |",
+ "| | | 80 |",
+ "+----+----+---------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+}
+
+// `COUNT(DISTINCT c3)` combined with `GROUP BY ROLLUP (c1,c2)`: the expected table
+// holds the (c1, c2) groups, one subtotal row per c1 (blank c2) and one grand-total
+// row; unlike the CUBE variant there are no rows with only c1 blank.
+#[tokio::test]
+async fn csv_query_rollup_distinct_count() {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+ let sql = "SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+---------------------------------------+",
+ "| c1 | c2 | COUNT(DISTINCT aggregate_test_100.c3) |",
+ "+----+----+---------------------------------------+",
+ "| a | 1 | 5 |",
+ "| a | 2 | 3 |",
+ "| a | 3 | 5 |",
+ "| a | 4 | 4 |",
+ "| a | 5 | 3 |",
+ "| a | | 19 |",
+ "| b | 1 | 3 |",
+ "| b | 2 | 4 |",
+ "| b | 3 | 2 |",
+ "| b | 4 | 5 |",
+ "| b | 5 | 5 |",
+ "| b | | 17 |",
+ "| c | 1 | 4 |",
+ "| c | 2 | 7 |",
+ "| c | 3 | 4 |",
+ "| c | 4 | 4 |",
+ "| c | 5 | 2 |",
+ "| c | | 21 |",
+ "| d | 1 | 7 |",
+ "| d | 2 | 3 |",
+ "| d | 3 | 3 |",
+ "| d | 4 | 3 |",
+ "| d | 5 | 2 |",
+ "| d | | 18 |",
+ "| e | 1 | 3 |",
+ "| e | 2 | 4 |",
+ "| e | 3 | 4 |",
+ "| e | 4 | 7 |",
+ "| e | 5 | 2 |",
+ "| e | | 18 |",
+ "| | | 80 |",
+ "+----+----+---------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+}
+
+// `GROUP BY ROLLUP (a.c1, b.c1)` with `SUM(a.c2)` over the cross join of
+// `aggregate_test_100` with itself. The expected table lists every (a.c1, b.c1)
+// pair, one subtotal row per a.c1 (blank b.c1) and one grand-total row.
+#[tokio::test]
+async fn csv_query_rollup_sum_crossjoin() {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+ let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+-----------+",
+ "| c1 | c1 | SUM(a.c2) |",
+ "+----+----+-----------+",
+ "| a | a | 1260 |",
+ "| a | b | 1140 |",
+ "| a | c | 1260 |",
+ "| a | d | 1080 |",
+ "| a | e | 1260 |",
+ "| a | | 6000 |",
+ "| b | a | 1302 |",
+ "| b | b | 1178 |",
+ "| b | c | 1302 |",
+ "| b | d | 1116 |",
+ "| b | e | 1302 |",
+ "| b | | 6200 |",
+ "| c | a | 1176 |",
+ "| c | b | 1064 |",
+ "| c | c | 1176 |",
+ "| c | d | 1008 |",
+ "| c | e | 1176 |",
+ "| c | | 5600 |",
+ "| d | a | 924 |",
+ "| d | b | 836 |",
+ "| d | c | 924 |",
+ "| d | d | 792 |",
+ "| d | e | 924 |",
+ "| d | | 4400 |",
+ "| e | a | 1323 |",
+ "| e | b | 1197 |",
+ "| e | c | 1323 |",
+ "| e | d | 1134 |",
+ "| e | e | 1323 |",
+ "| e | | 6300 |",
+ "| | | 28500 |",
+ "+----+----+-----------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+}
+
#[tokio::test]
async fn query_count_without_from() -> Result<()> {
let ctx = SessionContext::new();
@@ -675,6 +1068,59 @@ async fn csv_query_array_agg_with_overflow() -> Result<()> {
Ok(())
}
+// Runs SUM/AVG/MAX/MIN/COUNT over `c3` (aliased `sum_c3` .. `count_c3`) with
+// `GROUP BY CUBE (c1,c2)` and checks all five aggregates for every grouping set.
+// NOTE(review): the name mirrors the neighbouring `csv_query_array_agg_with_overflow`
+// test, but the query itself contains no array aggregate - consider renaming.
+#[tokio::test]
+async fn csv_query_array_cube_agg_with_overflow() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql =
+ "select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+--------+---------------------+--------+--------+----------+",
+ "| c1 | c2 | sum_c3 | avg_c3 | max_c3 | min_c3 | count_c3 |",
+ "+----+----+--------+---------------------+--------+--------+----------+",
+ "| a | 1 | -88 | -17.6 | 83 | -85 | 5 |",
+ "| a | 2 | -46 | -15.333333333333334 | 45 | -48 | 3 |",
+ "| a | 3 | -27 | -4.5 | 17 | -72 | 6 |",
+ "| a | 4 | -128 | -32 | 65 | -101 | 4 |",
+ "| a | 5 | -96 | -32 | 36 | -101 | 3 |",
+ "| a | | -385 | -18.333333333333332 | 83 | -101 | 21 |",
+ "| b | 1 | 95 | 31.666666666666668 | 54 | 12 | 3 |",
+ "| b | 2 | 102 | 25.5 | 68 | -60 | 4 |",
+ "| b | 3 | -84 | -42 | 17 | -101 | 2 |",
+ "| b | 4 | -223 | -44.6 | 47 | -117 | 5 |",
+ "| b | 5 | -1 | -0.2 | 68 | -82 | 5 |",
+ "| b | | -111 | -5.842105263157895 | 68 | -117 | 19 |",
+ "| c | 1 | 190 | 47.5 | 103 | -24 | 4 |",
+ "| c | 2 | -389 | -55.57142857142857 | 29 | -117 | 7 |",
+ "| c | 3 | 190 | 47.5 | 97 | -2 | 4 |",
+ "| c | 4 | -43 | -10.75 | 123 | -90 | 4 |",
+ "| c | 5 | 24 | 12 | 118 | -94 | 2 |",
+ "| c | | -28 | -1.3333333333333333 | 123 | -117 | 21 |",
+ "| d | 1 | -57 | -8.142857142857142 | 125 | -99 | 7 |",
+ "| d | 2 | 328 | 109.33333333333333 | 122 | 93 | 3 |",
+ "| d | 3 | 124 | 41.333333333333336 | 123 | -76 | 3 |",
+ "| d | 4 | 162 | 54 | 102 | 5 | 3 |",
+ "| d | 5 | -99 | -49.5 | -40 | -59 | 2 |",
+ "| d | | 458 | 25.444444444444443 | 125 | -99 | 18 |",
+ "| e | 1 | 227 | 75.66666666666667 | 120 | 36 | 3 |",
+ "| e | 2 | 189 | 37.8 | 97 | -61 | 5 |",
+ "| e | 3 | 192 | 48 | 112 | -95 | 4 |",
+ "| e | 4 | 261 | 37.285714285714285 | 97 | -56 | 7 |",
+ "| e | 5 | -22 | -11 | 64 | -86 | 2 |",
+ "| e | | 847 | 40.333333333333336 | 120 | -95 | 21 |",
+ "| | 1 | 367 | 16.681818181818183 | 125 | -99 | 22 |",
+ "| | 2 | 184 | 8.363636363636363 | 122 | -117 | 22 |",
+ "| | 3 | 395 | 20.789473684210527 | 123 | -101 | 19 |",
+ "| | 4 | 29 | 1.2608695652173914 | 123 | -117 | 23 |",
+ "| | 5 | -194 | -13.857142857142858 | 118 | -101 | 14 |",
+ "| | | 781 | 7.81 | 125 | -117 | 100 |",
+ "+----+----+--------+---------------------+--------+--------+----------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn csv_query_array_agg_distinct() -> Result<()> {
let ctx = SessionContext::new();
@@ -1223,6 +1669,79 @@ async fn count_aggregated() -> Result<()> {
Ok(())
}
+// `COUNT(c3)` with `GROUP BY CUBE (c1, c2)` on the `test` table, executed through
+// `execute_with_partition` with a partition count of 4.
+#[tokio::test]
+async fn count_aggregated_cube() -> Result<()> {
+ let results = execute_with_partition(
+ "SELECT c1, c2, COUNT(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2",
+ 4,
+ )
+ .await?;
+
+ // NOTE(review): the rows below are in textual order (blank first, "10" before "2"),
+ // not in the query's ORDER BY order - presumably because `assert_batches_sorted_eq!`
+ // compares sorted lines; confirm against the macro definition.
+ let expected = vec![
+ "+----+----+----------------+",
+ "| c1 | c2 | COUNT(test.c3) |",
+ "+----+----+----------------+",
+ "| | | 40 |",
+ "| | 1 | 4 |",
+ "| | 10 | 4 |",
+ "| | 2 | 4 |",
+ "| | 3 | 4 |",
+ "| | 4 | 4 |",
+ "| | 5 | 4 |",
+ "| | 6 | 4 |",
+ "| | 7 | 4 |",
+ "| | 8 | 4 |",
+ "| | 9 | 4 |",
+ "| 0 | | 10 |",
+ "| 0 | 1 | 1 |",
+ "| 0 | 10 | 1 |",
+ "| 0 | 2 | 1 |",
+ "| 0 | 3 | 1 |",
+ "| 0 | 4 | 1 |",
+ "| 0 | 5 | 1 |",
+ "| 0 | 6 | 1 |",
+ "| 0 | 7 | 1 |",
+ "| 0 | 8 | 1 |",
+ "| 0 | 9 | 1 |",
+ "| 1 | | 10 |",
+ "| 1 | 1 | 1 |",
+ "| 1 | 10 | 1 |",
+ "| 1 | 2 | 1 |",
+ "| 1 | 3 | 1 |",
+ "| 1 | 4 | 1 |",
+ "| 1 | 5 | 1 |",
+ "| 1 | 6 | 1 |",
+ "| 1 | 7 | 1 |",
+ "| 1 | 8 | 1 |",
+ "| 1 | 9 | 1 |",
+ "| 2 | | 10 |",
+ "| 2 | 1 | 1 |",
+ "| 2 | 10 | 1 |",
+ "| 2 | 2 | 1 |",
+ "| 2 | 3 | 1 |",
+ "| 2 | 4 | 1 |",
+ "| 2 | 5 | 1 |",
+ "| 2 | 6 | 1 |",
+ "| 2 | 7 | 1 |",
+ "| 2 | 8 | 1 |",
+ "| 2 | 9 | 1 |",
+ "| 3 | | 10 |",
+ "| 3 | 1 | 1 |",
+ "| 3 | 10 | 1 |",
+ "| 3 | 2 | 1 |",
+ "| 3 | 3 | 1 |",
+ "| 3 | 4 | 1 |",
+ "| 3 | 5 | 1 |",
+ "| 3 | 6 | 1 |",
+ "| 3 | 7 | 1 |",
+ "| 3 | 8 | 1 |",
+ "| 3 | 9 | 1 |",
+ "+----+----+----------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &results);
+ Ok(())
+}
+
#[tokio::test]
async fn simple_avg() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 7cf161697..202605b4d 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -270,6 +270,27 @@ pub enum GroupingSet {
GroupingSets(Vec<Vec<Expr>>),
}
+impl GroupingSet {
+    /// Returns the expressions referenced by this grouping set.
+    ///
+    /// * `ROLLUP` / `CUBE`: the underlying expression list is returned as written
+    ///   (cloned; no de-duplication is applied).
+    /// * `GROUPING SETS`: the individual sets are flattened in order and an expression
+    ///   is kept only the first time it is encountered.
+    pub fn distinct_expr(&self) -> Vec<Expr> {
+        let groups = match self {
+            GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => return exprs.clone(),
+            GroupingSet::GroupingSets(groups) => groups,
+        };
+
+        groups
+            .iter()
+            .flatten()
+            .fold(Vec::new(), |mut unique: Vec<Expr>, candidate| {
+                // Linear membership check keeps first-seen order.
+                if !unique.contains(candidate) {
+                    unique.push(candidate.clone());
+                }
+                unique
+            })
+    }
+}
+
/// Fixed seed for the hashing so that Ords are consistent across runs
const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 14eeb2c82..76bd9a975 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -17,6 +17,7 @@
//! Functions for creating logical expressions
+use crate::expr::GroupingSet;
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
@@ -226,6 +227,21 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
Expr::ScalarSubquery(Subquery { subquery })
}
+/// Build a `GROUPING SETS` expression; each inner `Vec` is one grouping set.
+pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
+    let sets = GroupingSet::GroupingSets(exprs);
+    Expr::GroupingSet(sets)
+}
+
+/// Build a `CUBE` expression, i.e. a grouping set covering all combinations of `exprs`.
+pub fn cube(exprs: Vec<Expr>) -> Expr {
+    let all_combinations = GroupingSet::Cube(exprs);
+    Expr::GroupingSet(all_combinations)
+}
+
+/// Build a `ROLLUP` expression over `exprs`.
+pub fn rollup(exprs: Vec<Expr>) -> Expr {
+    let hierarchy = GroupingSet::Rollup(exprs);
+    Expr::GroupingSet(hierarchy)
+}
+
// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
// varying arity functions
/// Create an convenience function representing a unary scalar function
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 8d58241ea..083b66c32 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -18,7 +18,9 @@
//! This module provides a builder for creating LogicalPlans
use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
-use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
+use crate::utils::{
+ columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
+};
use crate::{and, binary_expr, Operator};
use crate::{
logical_plan::{
@@ -694,7 +696,10 @@ impl LogicalPlanBuilder {
) -> Result<Self> {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
- let all_expr = group_expr.iter().chain(aggr_expr.iter());
+
+ let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
+
+ let all_expr = grouping_expr.iter().chain(aggr_expr.iter());
validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
let aggr_schema = DFSchema::new_with_metadata(
exprlist_to_fields(all_expr, &self.plan)?,
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 2120acaed..a85a817a8 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
Ok(())
}
+/// Expand a list of group by expressions into the distinct expressions it groups on.
+///
+/// When the first element is an `Expr::GroupingSet` it has to be the sole element:
+/// its distinct expressions are returned, and any additional element yields a
+/// `DataFusionError::Plan`. In every other case the input is returned as an owned copy.
+pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+    match group_expr.first() {
+        // A grouping set cannot be combined with further group by expressions.
+        Some(Expr::GroupingSet(_)) if group_expr.len() > 1 => Err(DataFusionError::Plan(
+            "Invalid group by expressions, GroupingSet must be the only expression"
+                .to_string(),
+        )),
+        Some(Expr::GroupingSet(grouping_set)) => Ok(grouping_set.distinct_expr()),
+        // Plain expressions or an empty list; only the first element is inspected.
+        _ => Ok(group_expr.to_vec()),
+    }
+}
+
/// Recursively walk an expression tree, collecting the unique set of column names
/// referenced in the expression
struct ColumnNameVisitor<'a> {
diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs
index c9aee1e03..ae2cc4fce 100644
--- a/datafusion/optimizer/src/projection_push_down.rs
+++ b/datafusion/optimizer/src/projection_push_down.rs
@@ -24,6 +24,7 @@ use arrow::error::Result as ArrowResult;
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema,
};
+use datafusion_expr::utils::grouping_set_to_exprlist;
use datafusion_expr::{
logical_plan::{
builder::{build_join_schema, LogicalPlanBuilder},
@@ -314,7 +315,10 @@ fn optimize_plan(
// * remove any aggregate expression that is not required
// * construct the new set of required columns
- exprlist_to_columns(group_expr, &mut new_required_columns)?;
+ // Find distinct group by exprs in the case where we have a grouping set
+ let all_group_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr)?;
+
+ exprlist_to_columns(&all_group_expr, &mut new_required_columns)?;
// Gather all columns needed for expressions in this Aggregate
let mut new_aggr_expr = Vec::new();
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index c508b9772..80214f302 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -19,6 +19,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, Result};
+use datafusion_expr::utils::grouping_set_to_exprlist;
use datafusion_expr::{
col,
logical_plan::{Aggregate, LogicalPlan, Projection},
@@ -62,9 +63,11 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
schema,
group_expr,
}) => {
- if is_single_distinct_agg(plan) {
+ if is_single_distinct_agg(plan) && !contains_grouping_set(group_expr) {
let mut group_fields_set = HashSet::new();
- let mut all_group_args = group_expr.clone();
+ let base_group_expr = grouping_set_to_exprlist(group_expr)?;
+ let mut all_group_args: Vec<Expr> = group_expr.clone();
+
// remove distinct and collection args
let new_aggr_expr = aggr_expr
.iter()
@@ -87,7 +90,9 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
})
.collect::<Vec<_>>();
- let all_field = all_group_args
+ let all_group_expr = grouping_set_to_exprlist(&all_group_args)?;
+
+ let all_field = all_group_expr
.iter()
.map(|expr| expr.to_field(input.schema()).unwrap())
.collect::<Vec<_>>();
@@ -106,7 +111,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
let grouped_agg = optimize_children(&grouped_agg);
let final_agg_schema = Arc::new(
DFSchema::new_with_metadata(
- group_expr
+ base_group_expr
.iter()
.chain(new_aggr_expr.iter())
.map(|expr| expr.to_field(&grouped_schema).unwrap())
@@ -115,18 +120,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
)
.unwrap(),
);
- let final_agg = LogicalPlan::Aggregate(Aggregate {
- input: Arc::new(grouped_agg.unwrap()),
- group_expr: group_expr.clone(),
- aggr_expr: new_aggr_expr,
- schema: final_agg_schema.clone(),
- });
// so the aggregates are displayed in the same way even after the rewrite
let mut alias_expr: Vec<Expr> = Vec::new();
- final_agg
- .expressions()
+ base_group_expr
.iter()
+ .chain(new_aggr_expr.iter())
.enumerate()
.for_each(|(i, field)| {
alias_expr.push(columnize_expr(
@@ -135,11 +134,18 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
));
});
+ let final_agg = LogicalPlan::Aggregate(Aggregate {
+ input: Arc::new(grouped_agg.unwrap()),
+ group_expr: group_expr.clone(),
+ aggr_expr: new_aggr_expr,
+ schema: final_agg_schema,
+ });
+
Ok(LogicalPlan::Projection(Projection {
expr: alias_expr,
input: Arc::new(final_agg),
schema: schema.clone(),
- alias: Option::None,
+ alias: None,
}))
} else {
optimize_children(plan)
@@ -185,6 +191,10 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> bool {
}
}
+/// Returns `true` when the first group by expression is an `Expr::GroupingSet`;
+/// later elements are not inspected.
+fn contains_grouping_set(expr: &[Expr]) -> bool {
+    expr.first()
+        .map_or(false, |leading| matches!(leading, Expr::GroupingSet(_)))
+}
+
impl OptimizerRule for SingleDistinctToGroupBy {
fn optimize(
&self,
@@ -202,6 +212,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
mod tests {
use super::*;
use crate::test::*;
+ use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
AggregateFunction,
@@ -212,6 +223,7 @@ mod tests {
let optimized_plan = rule
.optimize(plan, &OptimizerConfig::new())
.expect("failed to optimize plan");
+
let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
assert_eq!(formatted_plan, expected);
}
@@ -250,6 +262,69 @@ mod tests {
Ok(())
}
+    // GROUPING SETS must pass through untouched: this optimization is currently
+    // disabled for CUBE/ROLLUP/GROUPING SET.
+    #[test]
+    fn single_distinct_and_grouping_set() -> Result<()> {
+        let sets = vec![vec![col("a")], vec![col("b")]];
+        let group_by = Expr::GroupingSet(GroupingSet::GroupingSets(sets));
+        let distinct_count = count_distinct(col("c"));
+
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![group_by], vec![distinct_count])?
+            .build()?;
+
+        // The plan is expected to come back unchanged.
+        let expected = "Aggregate: groupBy=[[GROUPING SETS ((#test.a), (#test.b))]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+        \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    // CUBE must pass through untouched: this optimization is currently disabled for
+    // CUBE/ROLLUP/GROUPING SET.
+    #[test]
+    fn single_distinct_and_cube() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
+            .build()?;
+
+        // Should not be optimized
+        let expected = "Aggregate: groupBy=[[CUBE (#test.a, #test.b)]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+        \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    // ROLLUP must pass through untouched: this optimization is currently disabled for
+    // CUBE/ROLLUP/GROUPING SET.
+    #[test]
+    fn single_distinct_and_rollup() -> Result<()> {
+        let rollup_columns = vec![col("a"), col("b")];
+        let group_by = Expr::GroupingSet(GroupingSet::Rollup(rollup_columns));
+        let distinct_count = count_distinct(col("c"));
+
+        let plan = LogicalPlanBuilder::from(test_table_scan()?)
+            .aggregate(vec![group_by], vec![distinct_count])?
+            .build()?;
+
+        // The plan is expected to come back unchanged.
+        let expected = "Aggregate: groupBy=[[ROLLUP (#test.a, #test.b)]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+        \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
#[test]
fn single_distinct_expr() -> Result<()> {
let table_scan = test_table_scan()?;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 4522cd63e..dffc8ec2f 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -300,9 +300,33 @@ message LogicalExprNode {
ScalarUDFExprNode scalar_udf_expr = 20;
GetIndexedField get_indexed_field = 21;
+
+ GroupingSetNode grouping_set = 22;
+
+ CubeNode cube = 23;
+
+ RollupNode rollup = 24;
}
}
+// A plain list of logical expressions; used as the element type of
+// GroupingSetNode so that a list of expression lists can be encoded.
+message LogicalExprList {
+ repeated LogicalExprNode expr = 1;
+}
+
+// GROUPING SETS expression: each LogicalExprList entry is one grouping set.
+message GroupingSetNode {
+ repeated LogicalExprList expr = 1;
+}
+
+// CUBE expression: the list of expressions the cube is built over.
+message CubeNode {
+ repeated LogicalExprNode expr = 1;
+}
+
+// ROLLUP expression: the list of expressions the rollup is built over.
+message RollupNode {
+ repeated LogicalExprNode expr = 1;
+}
+
+
+
message GetIndexedField {
LogicalExprNode expr = 1;
ScalarValue key = 2;
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index c684b785e..279cb8e40 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -20,12 +20,17 @@ use crate::protobuf::plan_type::PlanTypeEnum::{
FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan,
OptimizedLogicalPlan, OptimizedPhysicalPlan,
};
-use crate::protobuf::{OptimizedLogicalPlanType, OptimizedPhysicalPlanType};
+use crate::protobuf::{
+ CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType,
+ RollupNode,
+};
use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode};
use datafusion::logical_plan::FunctionRegistry;
use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue,
};
+use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::expr::GroupingSet::GroupingSets;
use datafusion_expr::{
abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr,
coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp,
@@ -1290,6 +1295,32 @@ pub fn parse_expr(
.collect::<Result<Vec<_>, Error>>()?,
})
}
+
+ ExprType::GroupingSet(GroupingSetNode { expr }) => {
+ Ok(Expr::GroupingSet(GroupingSets(
+ expr.iter()
+ .map(|expr_list| {
+ expr_list
+ .expr
+ .iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, Error>>()
+ })
+ .collect::<Result<Vec<_>, Error>>()?,
+ )))
+ }
+ ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube(
+ expr.iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, Error>>()?,
+ ))),
+ ExprType::Rollup(RollupNode { expr }) => {
+ Ok(Expr::GroupingSet(GroupingSet::Rollup(
+ expr.iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, Error>>()?,
+ )))
+ }
}
}
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 6fe1aac68..f08a00b49 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -62,6 +62,7 @@ mod roundtrip_tests {
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext};
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
+ use datafusion_expr::expr::GroupingSet;
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
use datafusion_expr::{
col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, Expr,
@@ -1001,4 +1002,32 @@ mod roundtrip_tests {
roundtrip_expr_test!(test_expr, ctx);
}
+
+    // Round-trips a GROUPING SETS expression made of the sets (a), (b) and (a, b).
+    #[test]
+    fn roundtrip_grouping_sets() {
+        let ctx = SessionContext::new();
+
+        let sets = vec![vec![col("a")], vec![col("b")], vec![col("a"), col("b")]];
+        let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(sets));
+
+        roundtrip_expr_test!(test_expr, ctx);
+    }
+
+    // Round-trips a ROLLUP expression over columns `a` and `b`.
+    #[test]
+    fn roundtrip_rollup() {
+        let ctx = SessionContext::new();
+
+        let columns = vec![col("a"), col("b")];
+        let test_expr = Expr::GroupingSet(GroupingSet::Rollup(columns));
+
+        roundtrip_expr_test!(test_expr, ctx);
+    }
+
+    // Round-trips a CUBE expression over columns `a` and `b`.
+    #[test]
+    fn roundtrip_cube() {
+        let ctx = SessionContext::new();
+
+        let columns = vec![col("a"), col("b")];
+        let test_expr = Expr::GroupingSet(GroupingSet::Cube(columns));
+
+        roundtrip_expr_test!(test_expr, ctx);
+    }
}
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 8df8ff0dd..afe24ea89 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -25,12 +25,14 @@ use crate::protobuf::{
FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan,
OptimizedLogicalPlan, OptimizedPhysicalPlan,
},
- EmptyMessage, OptimizedLogicalPlanType, OptimizedPhysicalPlanType,
+ CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType,
+ OptimizedPhysicalPlanType, RollupNode,
};
use arrow::datatypes::{
DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode,
};
use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue};
+use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound,
@@ -718,9 +720,42 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
},
))),
},
- Expr::QualifiedWildcard { .. }
- | Expr::TryCast { .. }
- | Expr::GroupingSet(_) => unimplemented!(),
+
+ Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self {
+ expr_type: Some(ExprType::Cube(CubeNode {
+ expr: exprs.iter().map(|expr| expr.try_into()).collect::<Result<
+ Vec<_>,
+ Self::Error,
+ >>(
+ )?,
+ })),
+ },
+ Expr::GroupingSet(GroupingSet::Rollup(exprs)) => Self {
+ expr_type: Some(ExprType::Rollup(RollupNode {
+ expr: exprs.iter().map(|expr| expr.try_into()).collect::<Result<
+ Vec<_>,
+ Self::Error,
+ >>(
+ )?,
+ })),
+ },
+ Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => Self {
+ expr_type: Some(ExprType::GroupingSet(GroupingSetNode {
+ expr: exprs
+ .iter()
+ .map(|expr_list| {
+ Ok(LogicalExprList {
+ expr: expr_list
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, Self::Error>>()?,
+ })
+ })
+ .collect::<Result<Vec<_>, Self::Error>>()?,
+ })),
+ },
+
+ Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } => unimplemented!(),
};
Ok(expr_node)