You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/06/13 19:55:32 UTC

[arrow-datafusion] branch master updated: Support for GROUPING SETS/CUBE/ROLLUP (#2716)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new ca5339bfd Support for GROUPING SETS/CUBE/ROLLUP (#2716)
ca5339bfd is described below

commit ca5339bfd27677c165263aa01f263c8fba886a45
Author: Dan Harris <13...@users.noreply.github.com>
AuthorDate: Mon Jun 13 15:55:28 2022 -0400

    Support for GROUPING SETS/CUBE/ROLLUP (#2716)
    
    * WIP
    
    * Implement for non-row based accumulators
    
    * Non-row aggregations
    
    * Map logical plan and add some basic tests
    
    * Handle grouping sets in various optimize passes.
    
    * Implemented create_cube_expr and create_rollup_expr functions
    
    * Cleanup and ignore SingleDistinctToGroupBy when using grouping sets for now
    
    * Handle grouping sets in SingleDistinctToGroupBy
    
    * Add more tests and burn the boats
    
    * Fix(ish) partitioning
    
    * Serialization for grouping set exprs
    
    * fixed bug with create_cube_expr function
    
    * fixed bug with create_cube_expr function
    
    * Fixed bug in row-based-aggregation
    
    * Added unit tests for test_create_rollup_expr and test_create_cube_expr
    
    * Formatting
    
    * Tests, linter fixes and docs
    
    * Linting
    
    * Better encoding which avoids evaluating grouping expressions redundantly
    
    * Remove commented code
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * PR Comments: Rename PhysicalGroupingSet -> PhysicalGroupBy and clarify doc comment
    
    * Disable single_distinct_to_groupby for grouping sets for now and add unit tests for single distinct queries.
    
    * PR comments
    
    * Remove old comment
    
    * Return PhysicalGroupBy from AggregateExec::group_expr
    
    Co-authored-by: Ryan Tomczik <ry...@coralogix.com>
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 .../src/physical_optimizer/aggregate_statistics.rs |  26 +-
 .../core/src/physical_optimizer/repartition.rs     |   8 +-
 .../core/src/physical_plan/aggregates/hash.rs      | 276 +++++------
 .../core/src/physical_plan/aggregates/mod.rs       | 393 ++++++++++++++--
 .../core/src/physical_plan/aggregates/row_hash.rs  | 288 ++++++------
 datafusion/core/src/physical_plan/planner.rs       | 398 +++++++++++++++-
 datafusion/core/tests/dataframe.rs                 | 219 ++++++++-
 datafusion/core/tests/sql/aggregates.rs            | 519 +++++++++++++++++++++
 datafusion/expr/src/expr.rs                        |  21 +
 datafusion/expr/src/expr_fn.rs                     |  16 +
 datafusion/expr/src/logical_plan/builder.rs        |   9 +-
 datafusion/expr/src/utils.rs                       |  16 +
 datafusion/optimizer/src/projection_push_down.rs   |   6 +-
 .../optimizer/src/single_distinct_to_groupby.rs    | 101 +++-
 datafusion/proto/proto/datafusion.proto            |  24 +
 datafusion/proto/src/from_proto.rs                 |  33 +-
 datafusion/proto/src/lib.rs                        |  29 ++
 datafusion/proto/src/to_proto.rs                   |  43 +-
 18 files changed, 2044 insertions(+), 381 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index bcf4fec07..cafd61d9e 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -265,7 +265,7 @@ mod tests {
 
     use crate::error::Result;
     use crate::logical_plan::Operator;
-    use crate::physical_plan::aggregates::AggregateExec;
+    use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::common;
     use crate::physical_plan::expressions::Count;
@@ -407,7 +407,7 @@ mod tests {
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
@@ -415,7 +415,7 @@ mod tests {
 
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
@@ -435,7 +435,7 @@ mod tests {
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
@@ -443,7 +443,7 @@ mod tests {
 
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
@@ -462,7 +462,7 @@ mod tests {
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
@@ -473,7 +473,7 @@ mod tests {
 
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             Arc::new(coalesce),
             Arc::clone(&schema),
@@ -492,7 +492,7 @@ mod tests {
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
@@ -503,7 +503,7 @@ mod tests {
 
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             Arc::new(coalesce),
             Arc::clone(&schema),
@@ -533,7 +533,7 @@ mod tests {
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             filter,
             Arc::clone(&schema),
@@ -541,7 +541,7 @@ mod tests {
 
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
@@ -576,7 +576,7 @@ mod tests {
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             filter,
             Arc::clone(&schema),
@@ -584,7 +584,7 @@ mod tests {
 
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
-            vec![],
+            PhysicalGroupBy::default(),
             vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs
index b3b7ba948..e9e14abf6 100644
--- a/datafusion/core/src/physical_optimizer/repartition.rs
+++ b/datafusion/core/src/physical_optimizer/repartition.rs
@@ -242,7 +242,9 @@ mod tests {
     use super::*;
     use crate::datasource::listing::PartitionedFile;
     use crate::datasource::object_store::ObjectStoreUrl;
-    use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+    use crate::physical_plan::aggregates::{
+        AggregateExec, AggregateMode, PhysicalGroupBy,
+    };
     use crate::physical_plan::expressions::{col, PhysicalSortExpr};
     use crate::physical_plan::file_format::{FileScanConfig, ParquetExec};
     use crate::physical_plan::filter::FilterExec;
@@ -305,12 +307,12 @@ mod tests {
         Arc::new(
             AggregateExec::try_new(
                 AggregateMode::Final,
-                vec![],
+                PhysicalGroupBy::default(),
                 vec![],
                 Arc::new(
                     AggregateExec::try_new(
                         AggregateMode::Partial,
-                        vec![],
+                        PhysicalGroupBy::default(),
                         vec![],
                         input,
                         schema.clone(),
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index 45719260c..ddf9af18f 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -29,7 +29,7 @@ use futures::{
 
 use crate::error::Result;
 use crate::physical_plan::aggregates::{
-    evaluate, evaluate_many, AccumulatorItem, AggregateMode,
+    evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy,
 };
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
@@ -81,7 +81,7 @@ pub(crate) struct GroupedHashAggregateStream {
     aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
 
     aggr_expr: Vec<Arc<dyn AggregateExpr>>,
-    group_expr: Vec<Arc<dyn PhysicalExpr>>,
+    group_by: PhysicalGroupBy,
 
     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
@@ -93,7 +93,7 @@ impl GroupedHashAggregateStream {
     pub fn new(
         mode: AggregateMode,
         schema: SchemaRef,
-        group_expr: Vec<Arc<dyn PhysicalExpr>>,
+        group_by: PhysicalGroupBy,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
@@ -104,7 +104,7 @@ impl GroupedHashAggregateStream {
         // Assume create_schema() always put group columns in front of aggr columns, we set
         // col_idx_base to group expression count.
         let aggregate_expressions =
-            aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?;
+            aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?;
 
         timer.done();
 
@@ -113,7 +113,7 @@ impl GroupedHashAggregateStream {
             mode,
             input,
             aggr_expr,
-            group_expr,
+            group_by,
             baseline_metrics,
             aggregate_expressions,
             accumulators: Default::default(),
@@ -144,7 +144,7 @@ impl Stream for GroupedHashAggregateStream {
                     let result = group_aggregate_batch(
                         &this.mode,
                         &this.random_state,
-                        &this.group_expr,
+                        &this.group_by,
                         &this.aggr_expr,
                         batch,
                         &mut this.accumulators,
@@ -165,7 +165,7 @@ impl Stream for GroupedHashAggregateStream {
                     let result = create_batch_from_map(
                         &this.mode,
                         &this.accumulators,
-                        this.group_expr.len(),
+                        this.group_by.expr.len(),
                         &this.schema,
                     )
                     .record_output(&this.baseline_metrics);
@@ -191,152 +191,154 @@ impl RecordBatchStream for GroupedHashAggregateStream {
 fn group_aggregate_batch(
     mode: &AggregateMode,
     random_state: &RandomState,
-    group_expr: &[Arc<dyn PhysicalExpr>],
+    group_by: &PhysicalGroupBy,
     aggr_expr: &[Arc<dyn AggregateExpr>],
     batch: RecordBatch,
     accumulators: &mut Accumulators,
     aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
 ) -> Result<()> {
     // evaluate the grouping expressions
-    let group_values = evaluate(group_expr, &batch)?;
+    let group_by_values = evaluate_group_by(group_by, &batch)?;
 
     // evaluate the aggregation expressions.
     // We could evaluate them after the `take`, but since we need to evaluate all
     // of them anyways, it is more performant to do it while they are together.
     let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
 
-    // 1.1 construct the key from the group values
-    // 1.2 construct the mapping key if it does not exist
-    // 1.3 add the row' index to `indices`
-
-    // track which entries in `accumulators` have rows in this batch to aggregate
-    let mut groups_with_rows = vec![];
-
-    // 1.1 Calculate the group keys for the group values
-    let mut batch_hashes = vec![0; batch.num_rows()];
-    create_hashes(&group_values, random_state, &mut batch_hashes)?;
-
-    for (row, hash) in batch_hashes.into_iter().enumerate() {
-        let Accumulators { map, group_states } = accumulators;
-
-        let entry = map.get_mut(hash, |(_hash, group_idx)| {
-            // verify that a group that we are inserting with hash is
-            // actually the same key value as the group in
-            // existing_idx  (aka group_values @ row)
-            let group_state = &group_states[*group_idx];
-            group_values
-                .iter()
-                .zip(group_state.group_by_values.iter())
-                .all(|(array, scalar)| scalar.eq_array(array, row))
-        });
-
-        match entry {
-            // Existing entry for this group value
-            Some((_hash, group_idx)) => {
-                let group_state = &mut group_states[*group_idx];
-                // 1.3
-                if group_state.indices.is_empty() {
-                    groups_with_rows.push(*group_idx);
-                };
-                group_state.indices.push(row as u32); // remember this row
-            }
-            //  1.2 Need to create new entry
-            None => {
-                let accumulator_set = aggregates::create_accumulators(aggr_expr)?;
+    for grouping_set_values in group_by_values {
+        // 1.1 construct the key from the group values
+        // 1.2 construct the mapping key if it does not exist
+        // 1.3 add the row's index to `indices`
+
+        // track which entries in `accumulators` have rows in this batch to aggregate
+        let mut groups_with_rows = vec![];
+
+        // 1.1 Calculate the group keys for the group values
+        let mut batch_hashes = vec![0; batch.num_rows()];
+        create_hashes(&grouping_set_values, random_state, &mut batch_hashes)?;
 
-                // Copy group values out of arrays into `ScalarValue`s
-                let group_by_values = group_values
+        for (row, hash) in batch_hashes.into_iter().enumerate() {
+            let Accumulators { map, group_states } = accumulators;
+
+            let entry = map.get_mut(hash, |(_hash, group_idx)| {
+                // verify that a group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx  (aka group_values @ row)
+                let group_state = &group_states[*group_idx];
+                grouping_set_values
                     .iter()
-                    .map(|col| ScalarValue::try_from_array(col, row))
-                    .collect::<Result<Vec<_>>>()?;
-
-                // Add new entry to group_states and save newly created index
-                let group_state = GroupState {
-                    group_by_values: group_by_values.into_boxed_slice(),
-                    accumulator_set,
-                    indices: vec![row as u32], // 1.3
-                };
-                let group_idx = group_states.len();
-                group_states.push(group_state);
-                groups_with_rows.push(group_idx);
-
-                // for hasher function, use precomputed hash value
-                map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
-            }
-        };
-    }
+                    .zip(group_state.group_by_values.iter())
+                    .all(|(array, scalar)| scalar.eq_array(array, row))
+            });
+
+            match entry {
+                // Existing entry for this group value
+                Some((_hash, group_idx)) => {
+                    let group_state = &mut group_states[*group_idx];
+                    // 1.3
+                    if group_state.indices.is_empty() {
+                        groups_with_rows.push(*group_idx);
+                    };
+                    group_state.indices.push(row as u32); // remember this row
+                }
+                //  1.2 Need to create new entry
+                None => {
+                    let accumulator_set = aggregates::create_accumulators(aggr_expr)?;
+
+                    // Copy group values out of arrays into `ScalarValue`s
+                    let group_by_values = grouping_set_values
+                        .iter()
+                        .map(|col| ScalarValue::try_from_array(col, row))
+                        .collect::<Result<Vec<_>>>()?;
+
+                    // Add new entry to group_states and save newly created index
+                    let group_state = GroupState {
+                        group_by_values: group_by_values.into_boxed_slice(),
+                        accumulator_set,
+                        indices: vec![row as u32], // 1.3
+                    };
+                    let group_idx = group_states.len();
+                    group_states.push(group_state);
+                    groups_with_rows.push(group_idx);
+
+                    // for hasher function, use precomputed hash value
+                    map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
+                }
+            };
+        }
 
-    // Collect all indices + offsets based on keys in this vec
-    let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
-    let mut offsets = vec![0];
-    let mut offset_so_far = 0;
-    for group_idx in groups_with_rows.iter() {
-        let indices = &accumulators.group_states[*group_idx].indices;
-        batch_indices.append_slice(indices)?;
-        offset_so_far += indices.len();
-        offsets.push(offset_so_far);
-    }
-    let batch_indices = batch_indices.finish();
+        // Collect all indices + offsets based on keys in this vec
+        let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
+        let mut offsets = vec![0];
+        let mut offset_so_far = 0;
+        for group_idx in groups_with_rows.iter() {
+            let indices = &accumulators.group_states[*group_idx].indices;
+            batch_indices.append_slice(indices)?;
+            offset_so_far += indices.len();
+            offsets.push(offset_so_far);
+        }
+        let batch_indices = batch_indices.finish();
 
-    // `Take` all values based on indices into Arrays
-    let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
-        .iter()
-        .map(|array| {
-            array
-                .iter()
-                .map(|array| {
-                    compute::take(
-                        array.as_ref(),
-                        &batch_indices,
-                        None, // None: no index check
-                    )
-                    .unwrap()
-                })
-                .collect()
-            // 2.3
-        })
-        .collect();
-
-    // 2.1 for each key in this batch
-    // 2.2 for each aggregation
-    // 2.3 `slice` from each of its arrays the keys' values
-    // 2.4 update / merge the accumulator with the values
-    // 2.5 clear indices
-    groups_with_rows
-        .iter()
-        .zip(offsets.windows(2))
-        .try_for_each(|(group_idx, offsets)| {
-            let group_state = &mut accumulators.group_states[*group_idx];
-            // 2.2
-            group_state
-                .accumulator_set
-                .iter_mut()
-                .zip(values.iter())
-                .map(|(accumulator, aggr_array)| {
-                    (
-                        accumulator,
-                        aggr_array
-                            .iter()
-                            .map(|array| {
-                                // 2.3
-                                array.slice(offsets[0], offsets[1] - offsets[0])
-                            })
-                            .collect::<Vec<ArrayRef>>(),
-                    )
-                })
-                .try_for_each(|(accumulator, values)| match mode {
-                    AggregateMode::Partial => accumulator.update_batch(&values),
-                    AggregateMode::FinalPartitioned | AggregateMode::Final => {
-                        // note: the aggregation here is over states, not values, thus the merge
-                        accumulator.merge_batch(&values)
-                    }
-                })
-                // 2.5
-                .and({
-                    group_state.indices.clear();
-                    Ok(())
-                })
-        })?;
+        // `Take` all values based on indices into Arrays
+        let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
+            .iter()
+            .map(|array| {
+                array
+                    .iter()
+                    .map(|array| {
+                        compute::take(
+                            array.as_ref(),
+                            &batch_indices,
+                            None, // None: no index check
+                        )
+                        .unwrap()
+                    })
+                    .collect()
+                // 2.3
+            })
+            .collect();
+
+        // 2.1 for each key in this batch
+        // 2.2 for each aggregation
+        // 2.3 `slice` from each of its arrays the keys' values
+        // 2.4 update / merge the accumulator with the values
+        // 2.5 clear indices
+        groups_with_rows
+            .iter()
+            .zip(offsets.windows(2))
+            .try_for_each(|(group_idx, offsets)| {
+                let group_state = &mut accumulators.group_states[*group_idx];
+                // 2.2
+                group_state
+                    .accumulator_set
+                    .iter_mut()
+                    .zip(values.iter())
+                    .map(|(accumulator, aggr_array)| {
+                        (
+                            accumulator,
+                            aggr_array
+                                .iter()
+                                .map(|array| {
+                                    // 2.3
+                                    array.slice(offsets[0], offsets[1] - offsets[0])
+                                })
+                                .collect::<Vec<ArrayRef>>(),
+                        )
+                    })
+                    .try_for_each(|(accumulator, values)| match mode {
+                        AggregateMode::Partial => accumulator.update_batch(&values),
+                        AggregateMode::FinalPartitioned | AggregateMode::Final => {
+                            // note: the aggregation here is over states, not values, thus the merge
+                            accumulator.merge_batch(&values)
+                        }
+                    })
+                    // 2.5
+                    .and({
+                        group_state.indices.clear();
+                        Ok(())
+                    })
+            })?;
+    }
 
     Ok(())
 }
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index abe20cdcb..657b6281a 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -37,6 +37,7 @@ use datafusion_physical_expr::{
     expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr,
 };
 use std::any::Any;
+
 use std::sync::Arc;
 
 mod hash;
@@ -65,13 +66,93 @@ pub enum AggregateMode {
     FinalPartitioned,
 }
 
+/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SETS)
+/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
+/// and a single group [false, false].
+/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression
+/// into multiple groups, using null expressions to align each group.
+/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should
+/// create a `PhysicalGroupBy` like
+/// PhysicalGroupBy {
+///     expr: [(col(a), a), (col(b), b)],
+///     null_expr: [(NULL, a), (NULL, b)],
+///     groups: [
+///         [false, false], // (a,b)
+///         [false, true],  // (a) <=> (a, NULL)
+///         [true, false]   // (b) <=> (NULL, b)
+///     ]
+/// }
+#[derive(Clone, Debug, Default)]
+pub struct PhysicalGroupBy {
+    /// Distinct (Physical Expr, Alias) in the grouping set
+    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+    /// Corresponding NULL expressions for expr
+    null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+    /// Null mask for each group in this grouping set. Each group is
+    /// composed of either one of the group expressions in expr or a null
+    /// expression in null_expr. If groups[i][j] is true, then the
+    /// j-th expression in the i-th group is NULL, otherwise it is expr[j].
+    groups: Vec<Vec<bool>>,
+}
+
+impl PhysicalGroupBy {
+    /// Create a new `PhysicalGroupBy`
+    pub fn new(
+        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+        null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+        groups: Vec<Vec<bool>>,
+    ) -> Self {
+        Self {
+            expr,
+            null_expr,
+            groups,
+        }
+    }
+
+    /// Create a GROUPING SET with only a single group. This is the "standard"
+    /// case when building a plan from an expression such as `GROUP BY a,b,c`
+    pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
+        let num_exprs = expr.len();
+        Self {
+            expr,
+            null_expr: vec![],
+            groups: vec![vec![false; num_exprs]],
+        }
+    }
+
+    /// Returns true if this GROUP BY contains NULL expressions
+    pub fn contains_null(&self) -> bool {
+        self.groups.iter().flatten().any(|is_null| *is_null)
+    }
+
+    /// Returns the group expressions
+    pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
+        &self.expr
+    }
+
+    /// Returns the null expressions
+    pub fn null_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
+        &self.null_expr
+    }
+
+    /// Returns the group null masks
+    pub fn groups(&self) -> &[Vec<bool>] {
+        &self.groups
+    }
+
+    /// Returns true if this `PhysicalGroupBy` has no group expressions
+    pub fn is_empty(&self) -> bool {
+        self.expr.is_empty()
+    }
+}
+
 /// Hash aggregate execution plan
 #[derive(Debug)]
 pub struct AggregateExec {
     /// Aggregation mode (full, partial)
     mode: AggregateMode,
-    /// Grouping expressions
-    group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+    /// Group by expressions
+    group_by: PhysicalGroupBy,
     /// Aggregate expressions
     aggr_expr: Vec<Arc<dyn AggregateExpr>>,
     /// Input plan, could be a partial aggregate or the input to the aggregate
@@ -90,18 +171,24 @@ impl AggregateExec {
     /// Create a new hash aggregate execution plan
     pub fn try_new(
         mode: AggregateMode,
-        group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+        group_by: PhysicalGroupBy,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let schema = create_schema(&input.schema(), &group_expr, &aggr_expr, mode)?;
+        let schema = create_schema(
+            &input.schema(),
+            &group_by.expr,
+            &aggr_expr,
+            group_by.contains_null(),
+            mode,
+        )?;
 
         let schema = Arc::new(schema);
 
         Ok(AggregateExec {
             mode,
-            group_expr,
+            group_by,
             aggr_expr,
             input,
             schema,
@@ -116,15 +203,16 @@ impl AggregateExec {
     }
 
     /// Grouping expressions
-    pub fn group_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
-        &self.group_expr
+    pub fn group_expr(&self) -> &PhysicalGroupBy {
+        &self.group_by
     }
 
     /// Grouping expressions as they occur in the output schema
     pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
         // Update column indices. Since the group by columns come first in the output schema, their
         // indices are simply 0..self.group_expr(len).
-        self.group_expr
+        self.group_by
+            .expr()
             .iter()
             .enumerate()
             .map(|(index, (_col, name))| {
@@ -149,7 +237,7 @@ impl AggregateExec {
     }
 
     fn row_aggregate_supported(&self) -> bool {
-        let group_schema = group_schema(&self.schema, self.group_expr.len());
+        let group_schema = group_schema(&self.schema, self.group_by.expr.len());
         row_supported(&group_schema, RowType::Compact)
             && accumulator_v2_supported(&self.aggr_expr)
     }
@@ -178,7 +266,7 @@ impl ExecutionPlan for AggregateExec {
         match &self.mode {
             AggregateMode::Partial => Distribution::UnspecifiedDistribution,
             AggregateMode::FinalPartitioned => Distribution::HashPartitioned(
-                self.group_expr.iter().map(|x| x.0.clone()).collect(),
+                self.group_by.expr.iter().map(|x| x.0.clone()).collect(),
             ),
             AggregateMode::Final => Distribution::SinglePartition,
         }
@@ -198,7 +286,7 @@ impl ExecutionPlan for AggregateExec {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         Ok(Arc::new(AggregateExec::try_new(
             self.mode,
-            self.group_expr.clone(),
+            self.group_by.clone(),
             self.aggr_expr.clone(),
             children[0].clone(),
             self.input_schema.clone(),
@@ -211,11 +299,10 @@ impl ExecutionPlan for AggregateExec {
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         let input = self.input.execute(partition, context)?;
-        let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
 
-        if self.group_expr.is_empty() {
+        if self.group_by.expr.is_empty() {
             Ok(Box::pin(AggregateStream::new(
                 self.mode,
                 self.schema.clone(),
@@ -227,7 +314,7 @@ impl ExecutionPlan for AggregateExec {
             Ok(Box::pin(GroupedHashAggregateStreamV2::new(
                 self.mode,
                 self.schema.clone(),
-                group_expr,
+                self.group_by.clone(),
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
@@ -236,7 +323,7 @@ impl ExecutionPlan for AggregateExec {
             Ok(Box::pin(GroupedHashAggregateStream::new(
                 self.mode,
                 self.schema.clone(),
-                group_expr,
+                self.group_by.clone(),
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
@@ -256,18 +343,53 @@ impl ExecutionPlan for AggregateExec {
         match t {
             DisplayFormatType::Default => {
                 write!(f, "AggregateExec: mode={:?}", self.mode)?;
-                let g: Vec<String> = self
-                    .group_expr
-                    .iter()
-                    .map(|(e, alias)| {
-                        let e = e.to_string();
-                        if &e != alias {
-                            format!("{} as {}", e, alias)
-                        } else {
-                            e
-                        }
-                    })
-                    .collect();
+                let g: Vec<String> = if self.group_by.groups.len() == 1 {
+                    self.group_by
+                        .expr
+                        .iter()
+                        .map(|(e, alias)| {
+                            let e = e.to_string();
+                            if &e != alias {
+                                format!("{} as {}", e, alias)
+                            } else {
+                                e
+                            }
+                        })
+                        .collect()
+                } else {
+                    self.group_by
+                        .groups
+                        .iter()
+                        .map(|group| {
+                            let terms = group
+                                .iter()
+                                .enumerate()
+                                .map(|(idx, is_null)| {
+                                    if *is_null {
+                                        let (e, alias) = &self.group_by.null_expr[idx];
+                                        let e = e.to_string();
+                                        if &e != alias {
+                                            format!("{} as {}", e, alias)
+                                        } else {
+                                            e
+                                        }
+                                    } else {
+                                        let (e, alias) = &self.group_by.expr[idx];
+                                        let e = e.to_string();
+                                        if &e != alias {
+                                            format!("{} as {}", e, alias)
+                                        } else {
+                                            e
+                                        }
+                                    }
+                                })
+                                .collect::<Vec<String>>()
+                                .join(", ");
+                            format!("({})", terms)
+                        })
+                        .collect()
+                };
+
                 write!(f, ", gby=[{}]", g.join(", "))?;
 
                 let a: Vec<String> = self
@@ -289,7 +411,7 @@ impl ExecutionPlan for AggregateExec {
         // - aggregations somtimes also preserve invariants such as min, max...
         match self.mode {
             AggregateMode::Final | AggregateMode::FinalPartitioned
-                if self.group_expr.is_empty() =>
+                if self.group_by.expr.is_empty() =>
             {
                 Statistics {
                     num_rows: Some(1),
@@ -306,6 +428,7 @@ fn create_schema(
     input_schema: &Schema,
     group_expr: &[(Arc<dyn PhysicalExpr>, String)],
     aggr_expr: &[Arc<dyn AggregateExpr>],
+    contains_null_expr: bool,
     mode: AggregateMode,
 ) -> datafusion_common::Result<Schema> {
     let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
@@ -313,7 +436,10 @@ fn create_schema(
         fields.push(Field::new(
             name,
             expr.data_type(input_schema)?,
-            expr.nullable(input_schema)?,
+            // In cases where we have multiple grouping sets, we will use NULL expressions in
+            // order to align the grouping sets. So the field must be nullable even if the underlying
+            // schema field is not.
+            contains_null_expr || expr.nullable(input_schema)?,
         ))
     }
 
@@ -469,11 +595,54 @@ fn evaluate_many(
         .collect::<Result<Vec<_>>>()
 }
 
+/// Evaluates the grouping expressions of `group_by` against `batch` and returns
+/// one list of key columns per grouping set in `group_by.groups`.
+///
+/// Column `idx` of a set comes from `group_by.null_expr[idx]` when that set's flag
+/// is `true`, otherwise from `group_by.expr[idx]`. Every expression is evaluated
+/// only once per batch; the resulting arrays are shared between sets by cloning
+/// the `ArrayRef` handles.
+///
+/// # Errors
+/// Returns the first error raised while evaluating any expression.
+fn evaluate_group_by(
+    group_by: &PhysicalGroupBy,
+    batch: &RecordBatch,
+) -> Result<Vec<Vec<ArrayRef>>> {
+    let num_rows = batch.num_rows();
+    let eval_all = |pairs: &[(Arc<dyn PhysicalExpr>, String)]| {
+        pairs
+            .iter()
+            .map(|(expr, _)| Ok(expr.evaluate(batch)?.into_array(num_rows)))
+            .collect::<Result<Vec<ArrayRef>>>()
+    };
+    let exprs = eval_all(&group_by.expr)?;
+    let null_exprs = eval_all(&group_by.null_expr)?;
+
+    let mut grouping_values = Vec::with_capacity(group_by.groups.len());
+    for group in &group_by.groups {
+        let columns: Vec<ArrayRef> = group
+            .iter()
+            .enumerate()
+            .map(|(idx, &is_null)| {
+                // a `true` flag means the column is absent from this set and is
+                // replaced by its NULL expression to keep the sets aligned
+                let source = if is_null { &null_exprs } else { &exprs };
+                source[idx].clone()
+            })
+            .collect();
+        grouping_values.push(columns);
+    }
+    Ok(grouping_values)
+}
+
 #[cfg(test)]
 mod tests {
     use crate::execution::context::TaskContext;
     use crate::from_slice::FromSlice;
-    use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+    use crate::physical_plan::aggregates::{
+        AggregateExec, AggregateMode, PhysicalGroupBy,
+    };
     use crate::physical_plan::expressions::{col, Avg};
     use crate::test::assert_is_pending;
     use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
@@ -482,7 +651,8 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
-    use datafusion_common::{DataFusionError, Result};
+    use datafusion_common::{DataFusionError, Result, ScalarValue};
+    use datafusion_physical_expr::expressions::{lit, Count};
     use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
     use futures::{FutureExt, Stream};
     use std::any::Any;
@@ -528,12 +698,129 @@ mod tests {
         )
     }
 
+    /// Runs `COUNT(1)` over the grouping sets (a, NULL), (NULL, b) and (a, b) in two
+    /// phases (Partial, then Final over the coalesced partials) and checks both phases'
+    /// batches plus the Final phase's `output_rows` metric.
+    async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+        let input_schema = input.schema();
+
+        let group_by = PhysicalGroupBy {
+            expr: vec![
+                (col("a", &input_schema)?, "a".to_string()),
+                (col("b", &input_schema)?, "b".to_string()),
+            ],
+            // NULL literals that stand in for a column left out of a grouping set
+            null_expr: vec![
+                (lit(ScalarValue::UInt32(None)), "a".to_string()),
+                (lit(ScalarValue::Float64(None)), "b".to_string()),
+            ],
+            groups: vec![
+                vec![false, true],  // (a, NULL)
+                vec![true, false],  // (NULL, b)
+                vec![false, false], // (a,b)
+            ],
+        };
+
+        let count_one: Arc<dyn AggregateExpr> = Arc::new(Count::new(
+            lit(ScalarValue::Int8(Some(1))),
+            "COUNT(1)".to_string(),
+            DataType::Int64,
+        ));
+        let aggregates = vec![count_one];
+
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+
+        // Phase 1: partial aggregation directly over the input
+        let partial_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            group_by.clone(),
+            aggregates.clone(),
+            input,
+            input_schema.clone(),
+        )?);
+        let partial_stream = partial_exec.execute(0, task_ctx.clone())?;
+        let partial_batches = common::collect(partial_stream).await?;
+
+        let expected_partial = vec![
+            "+---+---+-----------------+",
+            "| a | b | COUNT(1)[count] |",
+            "+---+---+-----------------+",
+            "|   | 1 | 2               |",
+            "|   | 2 | 2               |",
+            "|   | 3 | 2               |",
+            "|   | 4 | 2               |",
+            "| 2 |   | 2               |",
+            "| 2 | 1 | 2               |",
+            "| 3 |   | 3               |",
+            "| 3 | 2 | 2               |",
+            "| 3 | 3 | 1               |",
+            "| 4 |   | 3               |",
+            "| 4 | 3 | 1               |",
+            "| 4 | 4 | 2               |",
+            "+---+---+-----------------+",
+        ];
+        assert_batches_sorted_eq!(expected_partial, &partial_batches);
+
+        // Phase 2: coalesce the partial output and finalize, re-grouping by name
+        let partial_groups = partial_exec.group_expr().expr().to_vec();
+        let merge = Arc::new(CoalescePartitionsExec::new(partial_exec));
+
+        let final_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = partial_groups
+            .iter()
+            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
+            .collect::<Result<_>>()?;
+        let final_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::new_single(final_exprs),
+            aggregates,
+            merge,
+            input_schema,
+        )?);
+
+        // the Final phase is expected to emit one batch holding all 12 groups
+        let final_batches =
+            common::collect(final_exec.execute(0, task_ctx.clone())?).await?;
+        assert_eq!(final_batches.len(), 1);
+        let batch = &final_batches[0];
+        assert_eq!(batch.num_columns(), 3);
+        assert_eq!(batch.num_rows(), 12);
+
+        let expected_final = vec![
+            "+---+---+----------+",
+            "| a | b | COUNT(1) |",
+            "+---+---+----------+",
+            "|   | 1 | 2        |",
+            "|   | 2 | 2        |",
+            "|   | 3 | 2        |",
+            "|   | 4 | 2        |",
+            "| 2 |   | 2        |",
+            "| 2 | 1 | 2        |",
+            "| 3 |   | 3        |",
+            "| 3 | 2 | 2        |",
+            "| 3 | 3 | 1        |",
+            "| 4 |   | 3        |",
+            "| 4 | 3 | 1        |",
+            "| 4 | 4 | 2        |",
+            "+---+---+----------+",
+        ];
+        assert_batches_sorted_eq!(&expected_final, &final_batches);
+
+        let output_rows = final_exec.metrics().unwrap().output_rows().unwrap();
+        assert_eq!(12, output_rows);
+
+        Ok(())
+    }
+
     /// build the aggregates on the data from some_data() and check the results
     async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
         let input_schema = input.schema();
 
-        let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
-            vec![(col("a", &input_schema)?, "a".to_string())];
+        let grouping_set = PhysicalGroupBy {
+            expr: vec![(col("a", &input_schema)?, "a".to_string())],
+            null_expr: vec![],
+            groups: vec![vec![false]],
+        };
 
         let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
             col("b", &input_schema)?,
@@ -546,7 +833,7 @@ mod tests {
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Partial,
-            groups.clone(),
+            grouping_set.clone(),
             aggregates.clone(),
             input,
             input_schema.clone(),
@@ -568,17 +855,17 @@ mod tests {
 
         let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
 
-        let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
-            .map(|i| col(&groups[i].1, &input_schema))
+        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
+            .expr
+            .iter()
+            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
             .collect::<Result<_>>()?;
 
+        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+
         let merged_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Final,
-            final_group
-                .iter()
-                .enumerate()
-                .map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
-                .collect(),
+            final_grouping_set,
             aggregates,
             merge,
             input_schema,
@@ -719,6 +1006,14 @@ mod tests {
         check_aggregates(input).await
     }
 
+    /// Runs the grouping-sets check over `TestYieldingExec { yield_first: false }`.
+    #[tokio::test]
+    async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
+        let source = TestYieldingExec { yield_first: false };
+        let input: Arc<dyn ExecutionPlan> = Arc::new(source);
+        check_grouping_sets(input).await
+    }
+
     #[tokio::test]
     async fn aggregate_source_with_yielding() -> Result<()> {
         let input: Arc<dyn ExecutionPlan> =
@@ -727,6 +1022,14 @@ mod tests {
         check_aggregates(input).await
     }
 
+    /// Runs the grouping-sets check over `TestYieldingExec { yield_first: true }`.
+    #[tokio::test]
+    async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
+        let source = TestYieldingExec { yield_first: true };
+        let input: Arc<dyn ExecutionPlan> = Arc::new(source);
+        check_grouping_sets(input).await
+    }
+
     #[tokio::test]
     async fn test_drop_cancel_without_groups() -> Result<()> {
         let session_ctx = SessionContext::new();
@@ -734,7 +1037,7 @@ mod tests {
         let schema =
             Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
 
-        let groups = vec![];
+        let groups = PhysicalGroupBy::default();
 
         let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
             col("a", &schema)?,
@@ -771,8 +1074,8 @@ mod tests {
             Field::new("b", DataType::Float32, true),
         ]));
 
-        let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
-            vec![(col("a", &schema)?, "a".to_string())];
+        let groups =
+            PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
 
         let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
             col("b", &schema)?,
@@ -784,7 +1087,7 @@ mod tests {
         let refs = blocking_exec.refs();
         let aggregate_exec = Arc::new(AggregateExec::try_new(
             AggregateMode::Partial,
-            groups.clone(),
+            groups,
             aggregates.clone(),
             blocking_exec,
             schema,
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index e364048e7..5353bc745 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -29,7 +29,8 @@ use futures::{
 
 use crate::error::Result;
 use crate::physical_plan::aggregates::{
-    evaluate, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode,
+    evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode,
+    PhysicalGroupBy,
 };
 use crate::physical_plan::hash_utils::create_row_hashes;
 use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
@@ -75,7 +76,7 @@ pub(crate) struct GroupedHashAggregateStreamV2 {
     aggr_state: AggregationState,
     aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
 
-    group_expr: Vec<Arc<dyn PhysicalExpr>>,
+    group_by: PhysicalGroupBy,
     accumulators: Vec<AccumulatorItemV2>,
 
     group_schema: SchemaRef,
@@ -100,7 +101,7 @@ impl GroupedHashAggregateStreamV2 {
     pub fn new(
         mode: AggregateMode,
         schema: SchemaRef,
-        group_expr: Vec<Arc<dyn PhysicalExpr>>,
+        group_by: PhysicalGroupBy,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
@@ -111,11 +112,11 @@ impl GroupedHashAggregateStreamV2 {
         // Assume create_schema() always put group columns in front of aggr columns, we set
         // col_idx_base to group expression count.
         let aggregate_expressions =
-            aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?;
+            aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?;
 
         let accumulators = aggregates::create_accumulators_v2(&aggr_expr)?;
 
-        let group_schema = group_schema(&schema, group_expr.len());
+        let group_schema = group_schema(&schema, group_by.expr.len());
         let aggr_schema = aggr_state_schema(&aggr_expr)?;
 
         let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned));
@@ -125,7 +126,7 @@ impl GroupedHashAggregateStreamV2 {
             schema,
             mode,
             input,
-            group_expr,
+            group_by,
             accumulators,
             group_schema,
             aggr_schema,
@@ -160,7 +161,7 @@ impl Stream for GroupedHashAggregateStreamV2 {
                     let result = group_aggregate_batch(
                         &this.mode,
                         &this.random_state,
-                        &this.group_expr,
+                        &this.group_by,
                         &mut this.accumulators,
                         &this.group_schema,
                         this.aggr_layout.clone(),
@@ -212,7 +213,10 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 {
+/// Aggregates the rows of `batch` into `aggr_state` once per grouping set in `grouping_set`:
+/// each set's group keys are row-encoded and hashed to find or create a group, whose rows
+/// are then fed to the accumulators (`update_batch` in Partial mode, `merge_batch` otherwise).
 fn group_aggregate_batch(
     mode: &AggregateMode,
     random_state: &RandomState,
-    group_expr: &[Arc<dyn PhysicalExpr>],
+    grouping_set: &PhysicalGroupBy,
     accumulators: &mut [AccumulatorItemV2],
     group_schema: &Schema,
     state_layout: Arc<RowLayout>,
@@ -221,142 +225,145 @@ fn group_aggregate_batch(
     aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
 ) -> Result<()> {
     // evaluate the grouping expressions
-    let group_values = evaluate(group_expr, &batch)?;
-    let group_rows: Vec<Vec<u8>> = create_group_rows(group_values, group_schema);
-
-    // evaluate the aggregation expressions.
-    // We could evaluate them after the `take`, but since we need to evaluate all
-    // of them anyways, it is more performant to do it while they are together.
-    let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
-
-    // 1.1 construct the key from the group values
-    // 1.2 construct the mapping key if it does not exist
-    // 1.3 add the row' index to `indices`
-
-    // track which entries in `aggr_state` have rows in this batch to aggregate
-    let mut groups_with_rows = vec![];
-
-    // 1.1 Calculate the group keys for the group values
-    let mut batch_hashes = vec![0; batch.num_rows()];
-    create_row_hashes(&group_rows, random_state, &mut batch_hashes)?;
-
-    for (row, hash) in batch_hashes.into_iter().enumerate() {
-        let AggregationState { map, group_states } = aggr_state;
-
-        let entry = map.get_mut(hash, |(_hash, group_idx)| {
-            // verify that a group that we are inserting with hash is
-            // actually the same key value as the group in
-            // existing_idx  (aka group_values @ row)
-            let group_state = &group_states[*group_idx];
-            group_rows[row] == group_state.group_by_values
-        });
-
-        match entry {
-            // Existing entry for this group value
-            Some((_hash, group_idx)) => {
-                let group_state = &mut group_states[*group_idx];
-                // 1.3
-                if group_state.indices.is_empty() {
-                    groups_with_rows.push(*group_idx);
-                };
-                group_state.indices.push(row as u32); // remember this row
-            }
-            //  1.2 Need to create new entry
-            None => {
-                // Add new entry to group_states and save newly created index
-                let group_state = RowGroupState {
-                    group_by_values: group_rows[row].clone(),
-                    aggregation_buffer: vec![0; state_layout.fixed_part_width()],
-                    indices: vec![row as u32], // 1.3
-                };
-                let group_idx = group_states.len();
-                group_states.push(group_state);
-                groups_with_rows.push(group_idx);
-
-                // for hasher function, use precomputed hash value
-                map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
-            }
-        };
-    }
-
-    // Collect all indices + offsets based on keys in this vec
-    let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
-    let mut offsets = vec![0];
-    let mut offset_so_far = 0;
-    for group_idx in groups_with_rows.iter() {
-        let indices = &aggr_state.group_states[*group_idx].indices;
-        batch_indices.append_slice(indices)?;
-        offset_so_far += indices.len();
-        offsets.push(offset_so_far);
-    }
-    let batch_indices = batch_indices.finish();
+    let grouping_by_values = evaluate_group_by(grouping_set, &batch)?;
+
+    // evaluate the aggregation expressions once per batch, not once per grouping set:
+    // their inputs do not depend on the set. We could evaluate them after the `take`,
+    // but since we need all of them anyways, it is more performant to do it together.
+    let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
+
+    for group_values in grouping_by_values {
+        let group_rows: Vec<Vec<u8>> = create_group_rows(group_values, group_schema);
+
+        // 1.1 construct the key from the group values
+        // 1.2 construct the mapping key if it does not exist
+        // 1.3 add the row's index to `indices`
+
+        // track which entries in `aggr_state` have rows in this batch to aggregate
+        let mut groups_with_rows = vec![];
+
+        // 1.1 Calculate the group keys for the group values
+        let mut batch_hashes = vec![0; batch.num_rows()];
+        create_row_hashes(&group_rows, random_state, &mut batch_hashes)?;
+
+        for (row, hash) in batch_hashes.into_iter().enumerate() {
+            let AggregationState { map, group_states } = aggr_state;
+
+            let entry = map.get_mut(hash, |(_hash, group_idx)| {
+                // verify that a group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx  (aka group_values @ row)
+                let group_state = &group_states[*group_idx];
+                group_rows[row] == group_state.group_by_values
+            });
+
+            match entry {
+                // Existing entry for this group value
+                Some((_hash, group_idx)) => {
+                    let group_state = &mut group_states[*group_idx];
+                    // 1.3
+                    if group_state.indices.is_empty() {
+                        groups_with_rows.push(*group_idx);
+                    };
+                    group_state.indices.push(row as u32); // remember this row
+                }
+                //  1.2 Need to create new entry
+                None => {
+                    // Add new entry to group_states and save newly created index
+                    let group_state = RowGroupState {
+                        group_by_values: group_rows[row].clone(),
+                        aggregation_buffer: vec![0; state_layout.fixed_part_width()],
+                        indices: vec![row as u32], // 1.3
+                    };
+                    let group_idx = group_states.len();
+                    group_states.push(group_state);
+                    groups_with_rows.push(group_idx);
+
+                    // for hasher function, use precomputed hash value
+                    map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
+                }
+            };
+        }
 
-    // `Take` all values based on indices into Arrays
-    let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
-        .iter()
-        .map(|array| {
-            array
-                .iter()
-                .map(|array| {
-                    compute::take(
-                        array.as_ref(),
-                        &batch_indices,
-                        None, // None: no index check
-                    )
-                    .unwrap()
-                })
-                .collect()
-            // 2.3
-        })
-        .collect();
-
-    // 2.1 for each key in this batch
-    // 2.2 for each aggregation
-    // 2.3 `slice` from each of its arrays the keys' values
-    // 2.4 update / merge the accumulator with the values
-    // 2.5 clear indices
-    groups_with_rows
-        .iter()
-        .zip(offsets.windows(2))
-        .try_for_each(|(group_idx, offsets)| {
-            let group_state = &mut aggr_state.group_states[*group_idx];
-            // 2.2
-            accumulators
-                .iter_mut()
-                .zip(values.iter())
-                .map(|(accumulator, aggr_array)| {
-                    (
-                        accumulator,
-                        aggr_array
-                            .iter()
-                            .map(|array| {
-                                // 2.3
-                                array.slice(offsets[0], offsets[1] - offsets[0])
-                            })
-                            .collect::<Vec<ArrayRef>>(),
-                    )
-                })
-                .try_for_each(|(accumulator, values)| {
-                    let mut state_accessor =
-                        RowAccessor::new_from_layout(state_layout.clone());
-                    state_accessor
-                        .point_to(0, group_state.aggregation_buffer.as_mut_slice());
-                    match mode {
-                        AggregateMode::Partial => {
-                            accumulator.update_batch(&values, &mut state_accessor)
-                        }
-                        AggregateMode::FinalPartitioned | AggregateMode::Final => {
-                            // note: the aggregation here is over states, not values, thus the merge
-                            accumulator.merge_batch(&values, &mut state_accessor)
+        // Collect all indices + offsets based on keys in this vec
+        let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
+        let mut offsets = vec![0];
+        let mut offset_so_far = 0;
+        for group_idx in groups_with_rows.iter() {
+            let indices = &aggr_state.group_states[*group_idx].indices;
+            batch_indices.append_slice(indices)?;
+            offset_so_far += indices.len();
+            offsets.push(offset_so_far);
+        }
+        let batch_indices = batch_indices.finish();
+
+        // `Take` all values based on indices into Arrays
+        let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
+            .iter()
+            .map(|array| {
+                array
+                    .iter()
+                    .map(|array| {
+                        compute::take(
+                            array.as_ref(),
+                            &batch_indices,
+                            None, // None: no index check
+                        )
+                        .unwrap()
+                    })
+                    .collect()
+                // 2.3
+            })
+            .collect();
+
+        // 2.1 for each key in this batch
+        // 2.2 for each aggregation
+        // 2.3 `slice` from each of its arrays the keys' values
+        // 2.4 update / merge the accumulator with the values
+        // 2.5 clear indices
+        groups_with_rows
+            .iter()
+            .zip(offsets.windows(2))
+            .try_for_each(|(group_idx, offsets)| {
+                let group_state = &mut aggr_state.group_states[*group_idx];
+                // 2.2
+                accumulators
+                    .iter_mut()
+                    .zip(values.iter())
+                    .map(|(accumulator, aggr_array)| {
+                        (
+                            accumulator,
+                            aggr_array
+                                .iter()
+                                .map(|array| {
+                                    // 2.3
+                                    array.slice(offsets[0], offsets[1] - offsets[0])
+                                })
+                                .collect::<Vec<ArrayRef>>(),
+                        )
+                    })
+                    .try_for_each(|(accumulator, values)| {
+                        let mut state_accessor =
+                            RowAccessor::new_from_layout(state_layout.clone());
+                        state_accessor
+                            .point_to(0, group_state.aggregation_buffer.as_mut_slice());
+                        match mode {
+                            AggregateMode::Partial => {
+                                accumulator.update_batch(&values, &mut state_accessor)
+                            }
+                            AggregateMode::FinalPartitioned | AggregateMode::Final => {
+                                // note: the aggregation here is over states, not values, thus the merge
+                                accumulator.merge_batch(&values, &mut state_accessor)
+                            }
                         }
-                    }
-                })
-                // 2.5
-                .and({
-                    group_state.indices.clear();
-                    Ok(())
-                })
-        })?;
+                    })
+                    // 2.5
+                    .and({
+                        group_state.indices.clear();
+                        Ok(())
+                    })
+            })?;
+    }
 
     Ok(())
 }
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 60cc3b8de..14cdee301 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -37,7 +37,7 @@ use crate::logical_plan::{
 use crate::logical_plan::{Limit, Values};
 use crate::physical_expr::create_physical_expr;
 use crate::physical_optimizer::optimizer::PhysicalOptimizerRule;
-use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
+use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
 use crate::physical_plan::cross_join::CrossJoinExec;
 use crate::physical_plan::explain::ExplainExec;
 use crate::physical_plan::expressions::{Column, PhysicalSortExpr};
@@ -58,10 +58,13 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::DataType;
 use arrow::datatypes::{Schema, SchemaRef};
 use async_trait::async_trait;
+use datafusion_common::ScalarValue;
 use datafusion_expr::{expr::GroupingSet, utils::expr_to_columns};
+use datafusion_physical_expr::expressions::Literal;
 use datafusion_sql::utils::window_expr_common_partition_keys;
 use futures::future::BoxFuture;
 use futures::{FutureExt, StreamExt, TryStreamExt};
+use itertools::Itertools;
 use log::{debug, trace};
 use std::collections::{HashMap, HashSet};
 use std::fmt::Write;
@@ -535,20 +538,12 @@ impl DefaultPhysicalPlanner {
                     let physical_input_schema = input_exec.schema();
                     let logical_input_schema = input.as_ref().schema();
 
-                    let groups = group_expr
-                        .iter()
-                        .map(|e| {
-                            tuple_err((
-                                self.create_physical_expr(
-                                    e,
-                                    logical_input_schema,
-                                    &physical_input_schema,
-                                    session_state,
-                                ),
-                                physical_name(e),
-                            ))
-                        })
-                        .collect::<Result<Vec<_>>>()?;
+                    let groups = self.create_grouping_physical_expr(
+                        group_expr,
+                        logical_input_schema,
+                        &physical_input_schema,
+                        session_state)?;
+
                     let aggregates = aggr_expr
                         .iter()
                         .map(|e| {
@@ -574,6 +569,7 @@ impl DefaultPhysicalPlanner {
 
                     // TODO: dictionary type not yet supported in Hash Repartition
                     let contains_dict = groups
+                        .expr()
                         .iter()
                         .flat_map(|x| x.0.data_type(physical_input_schema.as_ref()))
                         .any(|x| matches!(x, DataType::Dictionary(_, _)));
@@ -603,13 +599,17 @@ impl DefaultPhysicalPlanner {
                         (initial_aggr, AggregateMode::Final)
                     };
 
-                    Ok(Arc::new(AggregateExec::try_new(
-                        next_partition_mode,
+                    let final_grouping_set = PhysicalGroupBy::new_single(
                         final_group
                             .iter()
                             .enumerate()
-                            .map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
-                            .collect(),
+                            .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
+                            .collect()
+                    );
+
+                    Ok(Arc::new(AggregateExec::try_new(
+                        next_partition_mode,
+                        final_grouping_set,
                         aggregates,
                         initial_aggr,
                         physical_input_schema.clone(),
@@ -1001,6 +1001,261 @@ impl DefaultPhysicalPlanner {
             exec_plan
         }.boxed()
     }
+
+    /// Lower the logical `GROUP BY` expressions of an aggregate into a
+    /// [`PhysicalGroupBy`].
+    ///
+    /// A lone `GROUPING SETS`, `CUBE` or `ROLLUP` expression is expanded into
+    /// aligned grouping sets by the matching helper. Any other input (one plain
+    /// expression, several expressions, or none) becomes a single grouping set
+    /// over all of the given expressions.
+    ///
+    /// NOTE(review): a grouping-set expression mixed with other expressions is
+    /// not expanded here; it is lowered like any plain expression.
+    fn create_grouping_physical_expr(
+        &self,
+        group_expr: &[Expr],
+        input_dfschema: &DFSchema,
+        input_schema: &Schema,
+        session_state: &SessionState,
+    ) -> Result<PhysicalGroupBy> {
+        // Lower one plain grouping expression into its (physical expr, name) pair.
+        let lower = |e: &Expr| {
+            tuple_err((
+                self.create_physical_expr(
+                    e,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                ),
+                physical_name(e),
+            ))
+        };
+
+        match group_expr {
+            [Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))] => {
+                merge_grouping_set_physical_expr(
+                    grouping_sets,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                )
+            }
+            [Expr::GroupingSet(GroupingSet::Cube(exprs))] => {
+                create_cube_physical_expr(
+                    exprs,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                )
+            }
+            [Expr::GroupingSet(GroupingSet::Rollup(exprs))] => {
+                create_rollup_physical_expr(
+                    exprs,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                )
+            }
+            [single] => Ok(PhysicalGroupBy::new_single(vec![lower(single)?])),
+            // Zero or several expressions: one grouping set over all of them.
+            exprs => Ok(PhysicalGroupBy::new_single(
+                exprs.iter().map(lower).collect::<Result<Vec<_>>>()?,
+            )),
+        }
+    }
+}
+
+/// Expand and align a `GROUPING SETS` expression.
+///
+/// Every distinct expression that appears in any of the grouping sets is
+/// lowered once, both as itself and as a typed NULL literal, in first-seen
+/// order. Each grouping set is then encoded as a mask over that shared list,
+/// where `true` marks an expression that is absent from the set and is
+/// therefore replaced by its NULL literal. This conforms all sets to the same
+/// columns in the same order.
+///
+/// For example, `GROUPING SETS ((a,b,c),(a),(b),(b,c))` is expanded to
+/// `GROUPING SETS ((a,b,c),(a,NULL,NULL),(NULL,b,NULL),(NULL,b,c))`.
+///
+/// See <https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS>
+fn merge_grouping_set_physical_expr(
+    grouping_sets: &[Vec<Expr>],
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<PhysicalGroupBy> {
+    // Distinct logical expressions across all sets, in first-seen order.
+    let mut all_exprs: Vec<Expr> = vec![];
+    let mut grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+    let mut null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+
+    for expr in grouping_sets.iter().flatten() {
+        if all_exprs.contains(expr) {
+            continue;
+        }
+        all_exprs.push(expr.clone());
+
+        grouping_set_expr.push(get_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+
+        null_exprs.push(get_null_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+    }
+
+    // One mask per grouping set, aligned index-for-index with `all_exprs`.
+    let merged_sets: Vec<Vec<bool>> = grouping_sets
+        .iter()
+        .map(|expr_group| {
+            all_exprs
+                .iter()
+                .map(|expr| !expr_group.contains(expr))
+                .collect()
+        })
+        .collect();
+
+    Ok(PhysicalGroupBy::new(
+        grouping_set_expr,
+        null_exprs,
+        merged_sets,
+    ))
+}
+
+/// Expand and align a `CUBE` expression, a special case of `GROUPING SETS`.
+///
+/// `CUBE (e1, ..., en)` groups by every subset of its expressions, so this
+/// produces `2^n` masks. In each mask `true` marks an expression that is
+/// replaced by its typed NULL literal. Masks are ordered by how many
+/// expressions they NULL out (none first, all last) and, for the same count,
+/// by the lexicographic order of the NULLed positions.
+///
+/// See <https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS>
+fn create_cube_physical_expr(
+    exprs: &[Expr],
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<PhysicalGroupBy> {
+    let num_of_exprs = exprs.len();
+    // A cube over n expressions has 2^n grouping sets (one per subset). This is
+    // only a capacity hint, so skip preallocation if 2^n does not fit a usize.
+    let num_groups = if num_of_exprs < usize::BITS as usize {
+        1usize << num_of_exprs
+    } else {
+        0
+    };
+
+    let mut null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        Vec::with_capacity(num_of_exprs);
+    let mut all_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        Vec::with_capacity(num_of_exprs);
+
+    for expr in exprs {
+        null_exprs.push(get_null_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+
+        all_exprs.push(get_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?)
+    }
+
+    let mut groups: Vec<Vec<bool>> = Vec::with_capacity(num_groups);
+
+    // The full grouping (nothing NULLed) comes first.
+    groups.push(vec![false; num_of_exprs]);
+
+    // Then every non-empty subset of positions to NULL, smallest subsets first.
+    for null_count in 1..=num_of_exprs {
+        for null_idx in (0..num_of_exprs).combinations(null_count) {
+            let mut next_group: Vec<bool> = vec![false; num_of_exprs];
+            for i in null_idx {
+                next_group[i] = true;
+            }
+            groups.push(next_group);
+        }
+    }
+
+    Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups))
+}
+
+/// Expand and align a `ROLLUP` expression, a special case of `GROUPING SETS`.
+///
+/// `ROLLUP (e1, ..., en)` yields `n + 1` masks: the grand total (every
+/// expression NULLed), then progressively longer prefixes `(e1)`, `(e1, e2)`,
+/// ... up to the full `(e1, ..., en)`. `true` marks an expression replaced by
+/// its typed NULL literal.
+///
+/// See <https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS>
+fn create_rollup_physical_expr(
+    exprs: &[Expr],
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<PhysicalGroupBy> {
+    let width = exprs.len();
+
+    // Lower each expression twice: as its typed NULL placeholder and as itself.
+    let mut null_exprs = Vec::with_capacity(width);
+    let mut all_exprs = Vec::with_capacity(width);
+    for expr in exprs {
+        null_exprs.push(get_null_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+        all_exprs.push(get_physical_expr_pair(
+            expr,
+            input_dfschema,
+            input_schema,
+            session_state,
+        )?);
+    }
+
+    // Mask number `kept` keeps the first `kept` expressions and NULLs the rest.
+    let groups: Vec<Vec<bool>> = (0..=width)
+        .map(|kept| (0..width).map(|idx| idx >= kept).collect())
+        .collect();
+
+    Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups))
+}
+
+/// For a given logical expr, build a properly typed NULL literal physical
+/// expression, paired with the expression's physical name.
+///
+/// The literal takes the data type that the lowered expression produces
+/// against `input_schema`, so a grouping set that omits this expression still
+/// yields a column of the same type (e.g. `Utf8(NULL)` for a string column).
+fn get_null_physical_expr_pair(
+    expr: &Expr,
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<(Arc<dyn PhysicalExpr>, String)> {
+    // Lower the real expression only to learn its name and output type.
+    let (physical_expr, physical_name) =
+        get_physical_expr_pair(expr, input_dfschema, input_schema, session_state)?;
+
+    let data_type = physical_expr.data_type(input_schema)?;
+    // Converting a `DataType` into a `ScalarValue` yields the NULL of that type.
+    let null_value: ScalarValue = (&data_type).try_into()?;
+
+    Ok((Arc::new(Literal::new(null_value)), physical_name))
+}
+
+/// Lower `expr` to a physical expression and pair it with its physical
+/// (display) name, as used for grouping columns.
+fn get_physical_expr_pair(
+    expr: &Expr,
+    input_dfschema: &DFSchema,
+    input_schema: &Schema,
+    session_state: &SessionState,
+) -> Result<(Arc<dyn PhysicalExpr>, String)> {
+    let props = &session_state.execution_props;
+    let lowered = create_physical_expr(expr, input_dfschema, input_schema, props)?;
+    Ok((lowered, physical_name(expr)?))
 }
 
 /// Create a window expression with a name from a logical expression
@@ -1303,6 +1558,7 @@ mod tests {
     };
     use arrow::datatypes::{DataType, Field, SchemaRef};
     use datafusion_common::{DFField, DFSchema, DFSchemaRef};
+    use datafusion_expr::expr::GroupingSet;
     use datafusion_expr::sum;
     use datafusion_expr::{col, lit};
     use fmt::Debug;
@@ -1346,6 +1602,60 @@ mod tests {
         Ok(())
     }
 
+    /// `CUBE (c1, c2, c3)` must lower to all 2^3 = 8 grouping-set masks, with
+    /// the full grouping first and the grand total (all NULL) last.
+    #[tokio::test]
+    async fn test_create_cube_expr() -> Result<()> {
+        let logical_plan = test_csv_scan().await?.build()?;
+        let physical_plan = plan(&logical_plan).await?;
+        let physical_schema = physical_plan.schema();
+        let session_state = make_session_state();
+
+        let cube = create_cube_physical_expr(
+            &[col("c1"), col("c2"), col("c3")],
+            logical_plan.schema(),
+            physical_schema.as_ref(),
+            &session_state,
+        );
+
+        let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
+        assert_eq!(expected, format!("{:?}", cube));
+
+        Ok(())
+    }
+
+    /// `ROLLUP (c1, c2, c3)` must lower to four masks, running from the grand
+    /// total (all NULL) up to the full grouping.
+    #[tokio::test]
+    async fn test_create_rollup_expr() -> Result<()> {
+        let logical_plan = test_csv_scan().await?.build()?;
+        let physical_plan = plan(&logical_plan).await?;
+        let physical_schema = physical_plan.schema();
+        let session_state = make_session_state();
+
+        let rollup = create_rollup_physical_expr(
+            &[col("c1"), col("c2"), col("c3")],
+            logical_plan.schema(),
+            physical_schema.as_ref(),
+            &session_state,
+        );
+
+        let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
+        assert_eq!(expected, format!("{:?}", rollup));
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_create_not() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
@@ -1620,6 +1930,34 @@ mod tests {
         Ok(())
     }
 
+    /// A grouping-set aggregate must keep its original input schema so that
+    /// other projects can serialize the plan.
+    #[tokio::test]
+    async fn hash_agg_grouping_set_input_schema() -> Result<()> {
+        let sets = vec![vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")]];
+        let logical_plan = test_csv_scan_with_name("aggregate_test_100")
+            .await?
+            .aggregate(
+                vec![Expr::GroupingSet(GroupingSet::GroupingSets(sets))],
+                vec![sum(col("c3"))],
+            )?
+            .build()?;
+
+        let execution_plan = plan(&logical_plan).await?;
+        let final_hash_agg = execution_plan
+            .as_any()
+            .downcast_ref::<AggregateExec>()
+            .expect("hash aggregate");
+
+        let output_schema = final_hash_agg.schema();
+        assert_eq!("SUM(aggregate_test_100.c3)", output_schema.field(2).name());
+
+        // The input to the partial aggregate must stay reachable for serde.
+        let input_schema = final_hash_agg.input_schema();
+        assert_eq!("c3", input_schema.field(2).name());
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn hash_agg_group_by_partitioned() -> Result<()> {
         let logical_plan = test_csv_scan()
@@ -1637,6 +1975,28 @@ mod tests {
         Ok(())
     }
 
+    /// Grouping-set aggregates should still be planned with the faster
+    /// `FinalPartitioned` mode rather than plain `Final`.
+    #[tokio::test]
+    async fn hash_agg_grouping_set_by_partitioned() -> Result<()> {
+        let sets = vec![vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")]];
+        let logical_plan = test_csv_scan()
+            .await?
+            .aggregate(
+                vec![Expr::GroupingSet(GroupingSet::GroupingSets(sets))],
+                vec![sum(col("c3"))],
+            )?
+            .build()?;
+
+        let formatted = format!("{:?}", plan(&logical_plan).await?);
+        assert!(formatted.contains("FinalPartitioned"));
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_explain() {
         let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 38f54a2a7..b25e83cb7 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -24,11 +24,14 @@ use datafusion::from_slice::FromSlice;
 use std::sync::Arc;
 
 use datafusion::assert_batches_eq;
+use datafusion::dataframe::DataFrame;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::logical_plan::{col, Expr};
+use datafusion::prelude::CsvReadOptions;
 use datafusion::{datasource::MemTable, prelude::JoinType};
-use datafusion_expr::lit;
+use datafusion_expr::expr::GroupingSet;
+use datafusion_expr::{avg, count, lit, sum};
 
 #[tokio::test]
 async fn join() -> Result<()> {
@@ -207,3 +210,217 @@ async fn select_with_alias_overwrite() -> Result<()> {
 
     Ok(())
 }
+
+/// `GROUPING SETS ((a), (b), (a, b))` over the four-row in-memory table,
+/// ordered descending with NULLs first on both columns for stable output.
+#[tokio::test]
+async fn test_grouping_sets() -> Result<()> {
+    let sets = vec![vec![col("a")], vec![col("b")], vec![col("a"), col("b")]];
+    let ordering: Vec<Expr> = ["a", "b"]
+        .iter()
+        .map(|&name| Expr::Sort {
+            expr: Box::new(col(name)),
+            asc: false,
+            nulls_first: true,
+        })
+        .collect();
+
+    let results = create_test_table()?
+        .aggregate(
+            vec![Expr::GroupingSet(GroupingSet::GroupingSets(sets))],
+            vec![count(col("a"))],
+        )?
+        .sort(ordering)?
+        .collect()
+        .await?;
+
+    let expected = vec![
+        "+-----------+-----+---------------+",
+        "| a         | b   | COUNT(test.a) |",
+        "+-----------+-----+---------------+",
+        "|           | 100 | 1             |",
+        "|           | 10  | 2             |",
+        "|           | 1   | 1             |",
+        "| abcDEF    |     | 1             |",
+        "| abcDEF    | 1   | 1             |",
+        "| abc123    |     | 1             |",
+        "| abc123    | 10  | 1             |",
+        "| CBAdef    |     | 1             |",
+        "| CBAdef    | 10  | 1             |",
+        "| 123AbcDef |     | 1             |",
+        "| 123AbcDef | 100 | 1             |",
+        "+-----------+-----+---------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// `COUNT(1)` under `GROUPING SETS ((c1), (c2))`: one count per `c1` value and
+/// one per `c2` value, with the other column NULL.
+#[tokio::test]
+async fn test_grouping_sets_count() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let sets = vec![vec![col("c1")], vec![col("c2")]];
+    let ordering: Vec<Expr> = ["c1", "c2"]
+        .iter()
+        .map(|&name| Expr::Sort {
+            expr: Box::new(col(name)),
+            asc: false,
+            nulls_first: true,
+        })
+        .collect();
+
+    let results = aggregates_table(&ctx)
+        .await?
+        .aggregate(
+            vec![Expr::GroupingSet(GroupingSet::GroupingSets(sets))],
+            vec![count(lit(1))],
+        )?
+        .sort(ordering)?
+        .collect()
+        .await?;
+
+    let expected = vec![
+        "+----+----+-----------------+",
+        "| c1 | c2 | COUNT(Int32(1)) |",
+        "+----+----+-----------------+",
+        "|    | 5  | 14              |",
+        "|    | 4  | 23              |",
+        "|    | 3  | 19              |",
+        "|    | 2  | 22              |",
+        "|    | 1  | 22              |",
+        "| e  |    | 21              |",
+        "| d  |    | 18              |",
+        "| c  |    | 21              |",
+        "| b  |    | 19              |",
+        "| a  |    | 21              |",
+        "+----+----+-----------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// `SUM` and `AVG` of `c3` under `GROUPING SETS ((c1), (c2), (c1, c2))`.
+///
+/// NOTE(review): despite its name, this test uses no `array_agg` and does not
+/// exercise overflow; it only checks SUM/AVG results. Consider renaming.
+#[tokio::test]
+async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let sets = vec![vec![col("c1")], vec![col("c2")], vec![col("c1"), col("c2")]];
+    let ordering: Vec<Expr> = ["c1", "c2"]
+        .iter()
+        .map(|&name| Expr::Sort {
+            expr: Box::new(col(name)),
+            asc: false,
+            nulls_first: true,
+        })
+        .collect();
+    let aggregates = vec![
+        sum(col("c3")).alias("sum_c3"),
+        avg(col("c3")).alias("avg_c3"),
+    ];
+
+    let results = aggregates_table(&ctx)
+        .await?
+        .aggregate(
+            vec![Expr::GroupingSet(GroupingSet::GroupingSets(sets))],
+            aggregates,
+        )?
+        .sort(ordering)?
+        .collect()
+        .await?;
+
+    let expected = vec![
+        "+----+----+--------+---------------------+",
+        "| c1 | c2 | sum_c3 | avg_c3              |",
+        "+----+----+--------+---------------------+",
+        "|    | 5  | -194   | -13.857142857142858 |",
+        "|    | 4  | 29     | 1.2608695652173914  |",
+        "|    | 3  | 395    | 20.789473684210527  |",
+        "|    | 2  | 184    | 8.363636363636363   |",
+        "|    | 1  | 367    | 16.681818181818183  |",
+        "| e  |    | 847    | 40.333333333333336  |",
+        "| e  | 5  | -22    | -11                 |",
+        "| e  | 4  | 261    | 37.285714285714285  |",
+        "| e  | 3  | 192    | 48                  |",
+        "| e  | 2  | 189    | 37.8                |",
+        "| e  | 1  | 227    | 75.66666666666667   |",
+        "| d  |    | 458    | 25.444444444444443  |",
+        "| d  | 5  | -99    | -49.5               |",
+        "| d  | 4  | 162    | 54                  |",
+        "| d  | 3  | 124    | 41.333333333333336  |",
+        "| d  | 2  | 328    | 109.33333333333333  |",
+        "| d  | 1  | -57    | -8.142857142857142  |",
+        "| c  |    | -28    | -1.3333333333333333 |",
+        "| c  | 5  | 24     | 12                  |",
+        "| c  | 4  | -43    | -10.75              |",
+        "| c  | 3  | 190    | 47.5                |",
+        "| c  | 2  | -389   | -55.57142857142857  |",
+        "| c  | 1  | 190    | 47.5                |",
+        "| b  |    | -111   | -5.842105263157895  |",
+        "| b  | 5  | -1     | -0.2                |",
+        "| b  | 4  | -223   | -44.6               |",
+        "| b  | 3  | -84    | -42                 |",
+        "| b  | 2  | 102    | 25.5                |",
+        "| b  | 1  | 95     | 31.666666666666668  |",
+        "| a  |    | -385   | -18.333333333333332 |",
+        "| a  | 5  | -96    | -32                 |",
+        "| a  | 4  | -128   | -32                 |",
+        "| a  | 3  | -27    | -4.5                |",
+        "| a  | 2  | -46    | -15.333333333333334 |",
+        "| a  | 1  | -88    | -17.6               |",
+        "+----+----+--------+---------------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// Register a four-row in-memory table `test(a: Utf8, b: Int32)` in a fresh
+/// context and return it as a `DataFrame`. Column `b` repeats the value 10.
+fn create_test_table() -> Result<Arc<DataFrame>> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Utf8, false),
+        Field::new("b", DataType::Int32, false),
+    ]));
+
+    let a_values =
+        StringArray::from_slice(&["abcDEF", "abc123", "CBAdef", "123AbcDef"]);
+    let b_values = Int32Array::from_slice(&[1, 10, 10, 100]);
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(a_values), Arc::new(b_values)],
+    )?;
+
+    let ctx = SessionContext::new();
+    let table = MemTable::try_new(schema, vec![vec![batch]])?;
+    ctx.register_table("test", Arc::new(table))?;
+    ctx.table("test")
+}
+
+/// Read `aggregate_test_100.csv` from the arrow test data directory into `ctx`.
+async fn aggregates_table(ctx: &SessionContext) -> Result<Arc<DataFrame>> {
+    let path = format!(
+        "{}/csv/aggregate_test_100.csv",
+        datafusion::test_util::arrow_test_data()
+    );
+    ctx.read_csv(path, CsvReadOptions::default()).await
+}
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index 08ccbe453..61b1a1afa 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -476,6 +476,205 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
     Ok(())
 }
 
+/// `GROUP BY CUBE (c1, c2)` must return the per-(c1, c2) averages, a subtotal
+/// per `c1` (NULL `c2`), a subtotal per `c2` (NULL `c1`) and a grand total
+/// (both NULL). Under this `ORDER BY`, the NULL rows follow the non-NULL
+/// values, as the expected table shows.
+#[tokio::test]
+async fn csv_query_cube_avg() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
+
+    let sql = "SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+----+----------------------------+",
+        "| c1 | c2 | AVG(aggregate_test_100.c3) |",
+        "+----+----+----------------------------+",
+        "| a  | 1  | -17.6                      |",
+        "| a  | 2  | -15.333333333333334        |",
+        "| a  | 3  | -4.5                       |",
+        "| a  | 4  | -32                        |",
+        "| a  | 5  | -32                        |",
+        "| a  |    | -18.333333333333332        |",
+        "| b  | 1  | 31.666666666666668         |",
+        "| b  | 2  | 25.5                       |",
+        "| b  | 3  | -42                        |",
+        "| b  | 4  | -44.6                      |",
+        "| b  | 5  | -0.2                       |",
+        "| b  |    | -5.842105263157895         |",
+        "| c  | 1  | 47.5                       |",
+        "| c  | 2  | -55.57142857142857         |",
+        "| c  | 3  | 47.5                       |",
+        "| c  | 4  | -10.75                     |",
+        "| c  | 5  | 12                         |",
+        "| c  |    | -1.3333333333333333        |",
+        "| d  | 1  | -8.142857142857142         |",
+        "| d  | 2  | 109.33333333333333         |",
+        "| d  | 3  | 41.333333333333336         |",
+        "| d  | 4  | 54                         |",
+        "| d  | 5  | -49.5                      |",
+        "| d  |    | 25.444444444444443         |",
+        "| e  | 1  | 75.66666666666667          |",
+        "| e  | 2  | 37.8                       |",
+        "| e  | 3  | 48                         |",
+        "| e  | 4  | 37.285714285714285         |",
+        "| e  | 5  | -11                        |",
+        "| e  |    | 40.333333333333336         |",
+        "|    | 1  | 16.681818181818183         |",
+        "|    | 2  | 8.363636363636363          |",
+        "|    | 3  | 20.789473684210527         |",
+        "|    | 4  | 1.2608695652173914         |",
+        "|    | 5  | -13.857142857142858        |",
+        "|    |    | 7.81                       |",
+        "+----+----+----------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+/// `GROUP BY ROLLUP (c1, c2, c3)` must return the groupings (c1, c2, c3),
+/// (c1, c2), (c1) and the grand total (), with rolled-up columns shown as NULL.
+#[tokio::test]
+async fn csv_query_rollup_avg() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
+
+    let sql = "SELECT c1, c2, c3, AVG(c4) FROM aggregate_test_100 GROUP BY ROLLUP (c1, c2, c3) ORDER BY c1, c2, c3";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+----+------+----------------------------+",
+        "| c1 | c2 | c3   | AVG(aggregate_test_100.c4) |",
+        "+----+----+------+----------------------------+",
+        "| a  | 1  | -85  | -15154                     |",
+        "| a  | 1  | -56  | 8692                       |",
+        "| a  | 1  | -25  | 15295                      |",
+        "| a  | 1  | -5   | 12636                      |",
+        "| a  | 1  | 83   | -14704                     |",
+        "| a  | 1  |      | 1353                       |",
+        "| a  | 2  | -48  | -18025                     |",
+        "| a  | 2  | -43  | 13080                      |",
+        "| a  | 2  | 45   | 15673                      |",
+        "| a  | 2  |      | 3576                       |",
+        "| a  | 3  | -72  | -11122                     |",
+        "| a  | 3  | -12  | -9168                      |",
+        "| a  | 3  | 13   | 22338.5                    |",
+        "| a  | 3  | 14   | 28162                      |",
+        "| a  | 3  | 17   | -22796                     |",
+        "| a  | 3  |      | 4958.833333333333          |",
+        "| a  | 4  | -101 | 11640                      |",
+        "| a  | 4  | -54  | -2376                      |",
+        "| a  | 4  | -38  | 20744                      |",
+        "| a  | 4  | 65   | -28462                     |",
+        "| a  | 4  |      | 386.5                      |",
+        "| a  | 5  | -101 | -12484                     |",
+        "| a  | 5  | -31  | -12907                     |",
+        "| a  | 5  | 36   | -16974                     |",
+        "| a  | 5  |      | -14121.666666666666        |",
+        "| a  |    |      | 306.04761904761904         |",
+        "| b  | 1  | 12   | 7652                       |",
+        "| b  | 1  | 29   | -18218                     |",
+        "| b  | 1  | 54   | -18410                     |",
+        "| b  | 1  |      | -9658.666666666666         |",
+        "| b  | 2  | -60  | -21739                     |",
+        "| b  | 2  | 31   | 23127                      |",
+        "| b  | 2  | 63   | 21456                      |",
+        "| b  | 2  | 68   | 15874                      |",
+        "| b  | 2  |      | 9679.5                     |",
+        "| b  | 3  | -101 | -13217                     |",
+        "| b  | 3  | 17   | 14457                      |",
+        "| b  | 3  |      | 620                        |",
+        "| b  | 4  | -117 | 19316                      |",
+        "| b  | 4  | -111 | -1967                      |",
+        "| b  | 4  | -59  | 25286                      |",
+        "| b  | 4  | 17   | -28070                     |",
+        "| b  | 4  | 47   | 20690                      |",
+        "| b  | 4  |      | 7051                       |",
+        "| b  | 5  | -82  | 22080                      |",
+        "| b  | 5  | -44  | 15788                      |",
+        "| b  | 5  | -5   | 24896                      |",
+        "| b  | 5  | 62   | 16337                      |",
+        "| b  | 5  | 68   | 21576                      |",
+        "| b  | 5  |      | 20135.4                    |",
+        "| b  |    |      | 7732.315789473684          |",
+        "| c  | 1  | -24  | -24085                     |",
+        "| c  | 1  | 41   | -4667                      |",
+        "| c  | 1  | 70   | 27752                      |",
+        "| c  | 1  | 103  | -22186                     |",
+        "| c  | 1  |      | -5796.5                    |",
+        "| c  | 2  | -117 | -30187                     |",
+        "| c  | 2  | -107 | -2904                      |",
+        "| c  | 2  | -106 | -1114                      |",
+        "| c  | 2  | -60  | -16312                     |",
+        "| c  | 2  | -29  | 25305                      |",
+        "| c  | 2  | 1    | 18109                      |",
+        "| c  | 2  | 29   | -3855                      |",
+        "| c  | 2  |      | -1565.4285714285713        |",
+        "| c  | 3  | -2   | -18655                     |",
+        "| c  | 3  | 22   | 13741                      |",
+        "| c  | 3  | 73   | -9565                      |",
+        "| c  | 3  | 97   | 29106                      |",
+        "| c  | 3  |      | 3656.75                    |",
+        "| c  | 4  | -90  | -2935                      |",
+        "| c  | 4  | -79  | 5281                       |",
+        "| c  | 4  | 3    | -30508                     |",
+        "| c  | 4  | 123  | 16620                      |",
+        "| c  | 4  |      | -2885.5                    |",
+        "| c  | 5  | -94  | -15880                     |",
+        "| c  | 5  | 118  | 19208                      |",
+        "| c  | 5  |      | 1664                       |",
+        "| c  |    |      | -1320.5238095238096        |",
+        "| d  | 1  | -99  | 5613                       |",
+        "| d  | 1  | -98  | 13630                      |",
+        "| d  | 1  | -72  | 25590                      |",
+        "| d  | 1  | -8   | 27138                      |",
+        "| d  | 1  | 38   | 18384                      |",
+        "| d  | 1  | 57   | 28781                      |",
+        "| d  | 1  | 125  | 31106                      |",
+        "| d  | 1  |      | 21463.14285714286          |",
+        "| d  | 2  | 93   | -12642                     |",
+        "| d  | 2  | 113  | 3917                       |",
+        "| d  | 2  | 122  | 10130                      |",
+        "| d  | 2  |      | 468.3333333333333          |",
+        "| d  | 3  | -76  | 8809                       |",
+        "| d  | 3  | 77   | 15091                      |",
+        "| d  | 3  | 123  | 29533                      |",
+        "| d  | 3  |      | 17811                      |",
+        "| d  | 4  | 5    | -7688                      |",
+        "| d  | 4  | 55   | -1471                      |",
+        "| d  | 4  | 102  | -24558                     |",
+        "| d  | 4  |      | -11239                     |",
+        "| d  | 5  | -59  | 2045                       |",
+        "| d  | 5  | -40  | 22614                      |",
+        "| d  | 5  |      | 12329.5                    |",
+        "| d  |    |      | 10890.111111111111         |",
+        "| e  | 1  | 36   | -21481                     |",
+        "| e  | 1  | 71   | -5479                      |",
+        "| e  | 1  | 120  | 10837                      |",
+        "| e  | 1  |      | -5374.333333333333         |",
+        "| e  | 2  | -61  | -2888                      |",
+        "| e  | 2  | 49   | 24495                      |",
+        "| e  | 2  | 52   | 5666                       |",
+        "| e  | 2  | 97   | 18167                      |",
+        "| e  | 2  |      | 10221.2                    |",
+        "| e  | 3  | -95  | 13611                      |",
+        "| e  | 3  | 71   | 194                        |",
+        "| e  | 3  | 104  | -25136                     |",
+        "| e  | 3  | 112  | -6823                      |",
+        "| e  | 3  |      | -4538.5                    |",
+        "| e  | 4  | -56  | -31500                     |",
+        "| e  | 4  | -53  | 13788                      |",
+        "| e  | 4  | 30   | -16110                     |",
+        "| e  | 4  | 73   | -22501                     |",
+        "| e  | 4  | 74   | -12612                     |",
+        "| e  | 4  | 96   | -30336                     |",
+        "| e  | 4  | 97   | -13181                     |",
+        "| e  | 4  |      | -16064.57142857143         |",
+        "| e  | 5  | -86  | 32514                      |",
+        "| e  | 5  | 64   | -26526                     |",
+        "| e  | 5  |      | 2994                       |",
+        "| e  |    |      | -4268.333333333333         |",
+        "|    |    |      | 2319.97                    |",
+        "+----+----+------+----------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
     let ctx = SessionContext::new();
@@ -583,6 +782,200 @@ async fn csv_query_sum_crossjoin() {
     assert_batches_eq!(expected, &actual);
 }
 
+#[tokio::test]
+async fn csv_query_cube_sum_crossjoin() {
+    // CUBE over both sides of a self cross join yields four row kinds:
+    // per (a.c1, b.c1) pair, per a.c1 (b.c1 NULL), per b.c1 (a.c1 NULL),
+    // and the grand total (both NULL).
+    let expected = vec![
+        "+----+----+-----------+",
+        "| c1 | c1 | SUM(a.c2) |",
+        "+----+----+-----------+",
+        "| a  | a  | 1260      |",
+        "| a  | b  | 1140      |",
+        "| a  | c  | 1260      |",
+        "| a  | d  | 1080      |",
+        "| a  | e  | 1260      |",
+        "| a  |    | 6000      |",
+        "| b  | a  | 1302      |",
+        "| b  | b  | 1178      |",
+        "| b  | c  | 1302      |",
+        "| b  | d  | 1116      |",
+        "| b  | e  | 1302      |",
+        "| b  |    | 6200      |",
+        "| c  | a  | 1176      |",
+        "| c  | b  | 1064      |",
+        "| c  | c  | 1176      |",
+        "| c  | d  | 1008      |",
+        "| c  | e  | 1176      |",
+        "| c  |    | 5600      |",
+        "| d  | a  | 924       |",
+        "| d  | b  | 836       |",
+        "| d  | c  | 924       |",
+        "| d  | d  | 792       |",
+        "| d  | e  | 924       |",
+        "| d  |    | 4400      |",
+        "| e  | a  | 1323      |",
+        "| e  | b  | 1197      |",
+        "| e  | c  | 1323      |",
+        "| e  | d  | 1134      |",
+        "| e  | e  | 1323      |",
+        "| e  |    | 6300      |",
+        "|    | a  | 5985      |",
+        "|    | b  | 5415      |",
+        "|    | c  | 5985      |",
+        "|    | d  | 5130      |",
+        "|    | e  | 5985      |",
+        "|    |    | 28500     |",
+        "+----+----+-----------+",
+    ];
+
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
+    let query = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1";
+    let batches = execute_to_batches(&ctx, query).await;
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn csv_query_cube_distinct_count() {
+    // COUNT(DISTINCT ...) under CUBE (c1, c2): detail rows, per-c1 and
+    // per-c2 subtotals (other column NULL), and the overall distinct count.
+    let expected = vec![
+        "+----+----+---------------------------------------+",
+        "| c1 | c2 | COUNT(DISTINCT aggregate_test_100.c3) |",
+        "+----+----+---------------------------------------+",
+        "| a  | 1  | 5                                     |",
+        "| a  | 2  | 3                                     |",
+        "| a  | 3  | 5                                     |",
+        "| a  | 4  | 4                                     |",
+        "| a  | 5  | 3                                     |",
+        "| a  |    | 19                                    |",
+        "| b  | 1  | 3                                     |",
+        "| b  | 2  | 4                                     |",
+        "| b  | 3  | 2                                     |",
+        "| b  | 4  | 5                                     |",
+        "| b  | 5  | 5                                     |",
+        "| b  |    | 17                                    |",
+        "| c  | 1  | 4                                     |",
+        "| c  | 2  | 7                                     |",
+        "| c  | 3  | 4                                     |",
+        "| c  | 4  | 4                                     |",
+        "| c  | 5  | 2                                     |",
+        "| c  |    | 21                                    |",
+        "| d  | 1  | 7                                     |",
+        "| d  | 2  | 3                                     |",
+        "| d  | 3  | 3                                     |",
+        "| d  | 4  | 3                                     |",
+        "| d  | 5  | 2                                     |",
+        "| d  |    | 18                                    |",
+        "| e  | 1  | 3                                     |",
+        "| e  | 2  | 4                                     |",
+        "| e  | 3  | 4                                     |",
+        "| e  | 4  | 7                                     |",
+        "| e  | 5  | 2                                     |",
+        "| e  |    | 18                                    |",
+        "|    | 1  | 22                                    |",
+        "|    | 2  | 20                                    |",
+        "|    | 3  | 17                                    |",
+        "|    | 4  | 23                                    |",
+        "|    | 5  | 14                                    |",
+        "|    |    | 80                                    |",
+        "+----+----+---------------------------------------+",
+    ];
+
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
+    let query = "SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2";
+    let batches = execute_to_batches(&ctx, query).await;
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn csv_query_rollup_distinct_count() {
+    // ROLLUP (c1, c2) only produces detail rows, per-c1 subtotals and the
+    // grand total -- unlike CUBE there are no per-c2 subtotal rows.
+    let expected = vec![
+        "+----+----+---------------------------------------+",
+        "| c1 | c2 | COUNT(DISTINCT aggregate_test_100.c3) |",
+        "+----+----+---------------------------------------+",
+        "| a  | 1  | 5                                     |",
+        "| a  | 2  | 3                                     |",
+        "| a  | 3  | 5                                     |",
+        "| a  | 4  | 4                                     |",
+        "| a  | 5  | 3                                     |",
+        "| a  |    | 19                                    |",
+        "| b  | 1  | 3                                     |",
+        "| b  | 2  | 4                                     |",
+        "| b  | 3  | 2                                     |",
+        "| b  | 4  | 5                                     |",
+        "| b  | 5  | 5                                     |",
+        "| b  |    | 17                                    |",
+        "| c  | 1  | 4                                     |",
+        "| c  | 2  | 7                                     |",
+        "| c  | 3  | 4                                     |",
+        "| c  | 4  | 4                                     |",
+        "| c  | 5  | 2                                     |",
+        "| c  |    | 21                                    |",
+        "| d  | 1  | 7                                     |",
+        "| d  | 2  | 3                                     |",
+        "| d  | 3  | 3                                     |",
+        "| d  | 4  | 3                                     |",
+        "| d  | 5  | 2                                     |",
+        "| d  |    | 18                                    |",
+        "| e  | 1  | 3                                     |",
+        "| e  | 2  | 4                                     |",
+        "| e  | 3  | 4                                     |",
+        "| e  | 4  | 7                                     |",
+        "| e  | 5  | 2                                     |",
+        "| e  |    | 18                                    |",
+        "|    |    | 80                                    |",
+        "+----+----+---------------------------------------+",
+    ];
+
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
+    let query = "SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2";
+    let batches = execute_to_batches(&ctx, query).await;
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn csv_query_rollup_sum_crossjoin() {
+    // ROLLUP (a.c1, b.c1) over a self cross join: per-pair sums, per-a.c1
+    // subtotals (b.c1 NULL) and the grand total; no per-b.c1 rows.
+    let expected = vec![
+        "+----+----+-----------+",
+        "| c1 | c1 | SUM(a.c2) |",
+        "+----+----+-----------+",
+        "| a  | a  | 1260      |",
+        "| a  | b  | 1140      |",
+        "| a  | c  | 1260      |",
+        "| a  | d  | 1080      |",
+        "| a  | e  | 1260      |",
+        "| a  |    | 6000      |",
+        "| b  | a  | 1302      |",
+        "| b  | b  | 1178      |",
+        "| b  | c  | 1302      |",
+        "| b  | d  | 1116      |",
+        "| b  | e  | 1302      |",
+        "| b  |    | 6200      |",
+        "| c  | a  | 1176      |",
+        "| c  | b  | 1064      |",
+        "| c  | c  | 1176      |",
+        "| c  | d  | 1008      |",
+        "| c  | e  | 1176      |",
+        "| c  |    | 5600      |",
+        "| d  | a  | 924       |",
+        "| d  | b  | 836       |",
+        "| d  | c  | 924       |",
+        "| d  | d  | 792       |",
+        "| d  | e  | 924       |",
+        "| d  |    | 4400      |",
+        "| e  | a  | 1323      |",
+        "| e  | b  | 1197      |",
+        "| e  | c  | 1323      |",
+        "| e  | d  | 1134      |",
+        "| e  | e  | 1323      |",
+        "| e  |    | 6300      |",
+        "|    |    | 28500     |",
+        "+----+----+-----------+",
+    ];
+
+    let ctx = SessionContext::new();
+    register_aggregate_csv_by_sql(&ctx).await;
+    let query = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1";
+    let batches = execute_to_batches(&ctx, query).await;
+    assert_batches_eq!(expected, &batches);
+}
+
 #[tokio::test]
 async fn query_count_without_from() -> Result<()> {
     let ctx = SessionContext::new();
@@ -675,6 +1068,59 @@ async fn csv_query_array_agg_with_overflow() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_array_cube_agg_with_overflow() -> Result<()> {
+    // Several aggregates (sum/avg/max/min/count) evaluated together under
+    // CUBE (c1, c2), including the per-c2 and grand-total rows.
+    let expected = vec![
+        "+----+----+--------+---------------------+--------+--------+----------+",
+        "| c1 | c2 | sum_c3 | avg_c3              | max_c3 | min_c3 | count_c3 |",
+        "+----+----+--------+---------------------+--------+--------+----------+",
+        "| a  | 1  | -88    | -17.6               | 83     | -85    | 5        |",
+        "| a  | 2  | -46    | -15.333333333333334 | 45     | -48    | 3        |",
+        "| a  | 3  | -27    | -4.5                | 17     | -72    | 6        |",
+        "| a  | 4  | -128   | -32                 | 65     | -101   | 4        |",
+        "| a  | 5  | -96    | -32                 | 36     | -101   | 3        |",
+        "| a  |    | -385   | -18.333333333333332 | 83     | -101   | 21       |",
+        "| b  | 1  | 95     | 31.666666666666668  | 54     | 12     | 3        |",
+        "| b  | 2  | 102    | 25.5                | 68     | -60    | 4        |",
+        "| b  | 3  | -84    | -42                 | 17     | -101   | 2        |",
+        "| b  | 4  | -223   | -44.6               | 47     | -117   | 5        |",
+        "| b  | 5  | -1     | -0.2                | 68     | -82    | 5        |",
+        "| b  |    | -111   | -5.842105263157895  | 68     | -117   | 19       |",
+        "| c  | 1  | 190    | 47.5                | 103    | -24    | 4        |",
+        "| c  | 2  | -389   | -55.57142857142857  | 29     | -117   | 7        |",
+        "| c  | 3  | 190    | 47.5                | 97     | -2     | 4        |",
+        "| c  | 4  | -43    | -10.75              | 123    | -90    | 4        |",
+        "| c  | 5  | 24     | 12                  | 118    | -94    | 2        |",
+        "| c  |    | -28    | -1.3333333333333333 | 123    | -117   | 21       |",
+        "| d  | 1  | -57    | -8.142857142857142  | 125    | -99    | 7        |",
+        "| d  | 2  | 328    | 109.33333333333333  | 122    | 93     | 3        |",
+        "| d  | 3  | 124    | 41.333333333333336  | 123    | -76    | 3        |",
+        "| d  | 4  | 162    | 54                  | 102    | 5      | 3        |",
+        "| d  | 5  | -99    | -49.5               | -40    | -59    | 2        |",
+        "| d  |    | 458    | 25.444444444444443  | 125    | -99    | 18       |",
+        "| e  | 1  | 227    | 75.66666666666667   | 120    | 36     | 3        |",
+        "| e  | 2  | 189    | 37.8                | 97     | -61    | 5        |",
+        "| e  | 3  | 192    | 48                  | 112    | -95    | 4        |",
+        "| e  | 4  | 261    | 37.285714285714285  | 97     | -56    | 7        |",
+        "| e  | 5  | -22    | -11                 | 64     | -86    | 2        |",
+        "| e  |    | 847    | 40.333333333333336  | 120    | -95    | 21       |",
+        "|    | 1  | 367    | 16.681818181818183  | 125    | -99    | 22       |",
+        "|    | 2  | 184    | 8.363636363636363   | 122    | -117   | 22       |",
+        "|    | 3  | 395    | 20.789473684210527  | 123    | -101   | 19       |",
+        "|    | 4  | 29     | 1.2608695652173914  | 123    | -117   | 23       |",
+        "|    | 5  | -194   | -13.857142857142858 | 118    | -101   | 14       |",
+        "|    |    | 781    | 7.81                | 125    | -117   | 100      |",
+        "+----+----+--------+---------------------+--------+--------+----------+",
+    ];
+
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let query =
+        "select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2";
+    let batches = execute_to_batches(&ctx, query).await;
+    assert_batches_eq!(expected, &batches);
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_array_agg_distinct() -> Result<()> {
     let ctx = SessionContext::new();
@@ -1223,6 +1669,79 @@ async fn count_aggregated() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn count_aggregated_cube() -> Result<()> {
+    // The query runs over 4 partitions, so the output is compared
+    // order-insensitively via `assert_batches_sorted_eq!`.
+    let sql = "SELECT c1, c2, COUNT(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2";
+    let partitions = 4;
+    let batches = execute_with_partition(sql, partitions).await?;
+
+    let expected = vec![
+        "+----+----+----------------+",
+        "| c1 | c2 | COUNT(test.c3) |",
+        "+----+----+----------------+",
+        "|    |    | 40             |",
+        "|    | 1  | 4              |",
+        "|    | 10 | 4              |",
+        "|    | 2  | 4              |",
+        "|    | 3  | 4              |",
+        "|    | 4  | 4              |",
+        "|    | 5  | 4              |",
+        "|    | 6  | 4              |",
+        "|    | 7  | 4              |",
+        "|    | 8  | 4              |",
+        "|    | 9  | 4              |",
+        "| 0  |    | 10             |",
+        "| 0  | 1  | 1              |",
+        "| 0  | 10 | 1              |",
+        "| 0  | 2  | 1              |",
+        "| 0  | 3  | 1              |",
+        "| 0  | 4  | 1              |",
+        "| 0  | 5  | 1              |",
+        "| 0  | 6  | 1              |",
+        "| 0  | 7  | 1              |",
+        "| 0  | 8  | 1              |",
+        "| 0  | 9  | 1              |",
+        "| 1  |    | 10             |",
+        "| 1  | 1  | 1              |",
+        "| 1  | 10 | 1              |",
+        "| 1  | 2  | 1              |",
+        "| 1  | 3  | 1              |",
+        "| 1  | 4  | 1              |",
+        "| 1  | 5  | 1              |",
+        "| 1  | 6  | 1              |",
+        "| 1  | 7  | 1              |",
+        "| 1  | 8  | 1              |",
+        "| 1  | 9  | 1              |",
+        "| 2  |    | 10             |",
+        "| 2  | 1  | 1              |",
+        "| 2  | 10 | 1              |",
+        "| 2  | 2  | 1              |",
+        "| 2  | 3  | 1              |",
+        "| 2  | 4  | 1              |",
+        "| 2  | 5  | 1              |",
+        "| 2  | 6  | 1              |",
+        "| 2  | 7  | 1              |",
+        "| 2  | 8  | 1              |",
+        "| 2  | 9  | 1              |",
+        "| 3  |    | 10             |",
+        "| 3  | 1  | 1              |",
+        "| 3  | 10 | 1              |",
+        "| 3  | 2  | 1              |",
+        "| 3  | 3  | 1              |",
+        "| 3  | 4  | 1              |",
+        "| 3  | 5  | 1              |",
+        "| 3  | 6  | 1              |",
+        "| 3  | 7  | 1              |",
+        "| 3  | 8  | 1              |",
+        "| 3  | 9  | 1              |",
+        "+----+----+----------------+",
+    ];
+    assert_batches_sorted_eq!(expected, &batches);
+    Ok(())
+}
+
 #[tokio::test]
 async fn simple_avg() -> Result<()> {
     let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 7cf161697..202605b4d 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -270,6 +270,27 @@ pub enum GroupingSet {
     GroupingSets(Vec<Vec<Expr>>),
 }
 
+impl GroupingSet {
+    /// Collect the unique expressions referenced by this grouping set.
+    ///
+    /// `ROLLUP` and `CUBE` already hold a flat expression list, which is
+    /// returned as a clone. For `GROUPING SETS` the expressions of every
+    /// inner set are flattened and deduplicated (by `PartialEq`), keeping
+    /// the order in which each expression first appears.
+    pub fn distinct_expr(&self) -> Vec<Expr> {
+        match self {
+            GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => exprs.clone(),
+            GroupingSet::GroupingSets(groups) => {
+                groups.iter().flatten().fold(Vec::new(), |mut unique, expr| {
+                    // linear `contains` keeps `Expr` free of Hash/Eq bounds;
+                    // grouping sets are expected to be small
+                    if !unique.contains(expr) {
+                        unique.push(expr.clone());
+                    }
+                    unique
+                })
+            }
+        }
+    }
+}
+
 /// Fixed seed for the hashing so that Ords are consistent across runs
 const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
 
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 14eeb2c82..76bd9a975 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -17,6 +17,7 @@
 
 //! Functions for creating logical expressions
 
+use crate::expr::GroupingSet;
 use crate::{
     aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit,
     logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
@@ -226,6 +227,21 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr {
     Expr::ScalarSubquery(Subquery { subquery })
 }
 
+/// Create a `GROUPING SETS` expression; each inner `Vec` is one set of
+/// grouping expressions.
+pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
+    let sets = GroupingSet::GroupingSets(exprs);
+    Expr::GroupingSet(sets)
+}
+
+/// Create a `CUBE` expression, i.e. a grouping set over every combination
+/// (subset) of `exprs`.
+pub fn cube(exprs: Vec<Expr>) -> Expr {
+    let sets = GroupingSet::Cube(exprs);
+    Expr::GroupingSet(sets)
+}
+
+/// Create a `ROLLUP` expression over `exprs`.
+pub fn rollup(exprs: Vec<Expr>) -> Expr {
+    let sets = GroupingSet::Rollup(exprs);
+    Expr::GroupingSet(sets)
+}
+
 // TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
 // varying arity functions
 /// Create an convenience function representing a unary scalar function
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 8d58241ea..083b66c32 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -18,7 +18,9 @@
 //! This module provides a builder for creating LogicalPlans
 
 use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs};
-use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
+use crate::utils::{
+    columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
+};
 use crate::{and, binary_expr, Operator};
 use crate::{
     logical_plan::{
@@ -694,7 +696,10 @@ impl LogicalPlanBuilder {
     ) -> Result<Self> {
         let group_expr = normalize_cols(group_expr, &self.plan)?;
         let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
-        let all_expr = group_expr.iter().chain(aggr_expr.iter());
+
+        let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
+
+        let all_expr = grouping_expr.iter().chain(aggr_expr.iter());
         validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
         let aggr_schema = DFSchema::new_with_metadata(
             exprlist_to_fields(all_expr, &self.plan)?,
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 2120acaed..a85a817a8 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
     Ok(())
 }
 
+/// Expand a list of GROUP BY expressions into the distinct expressions it
+/// groups on.
+///
+/// When the first entry is an `Expr::GroupingSet` it must be the only
+/// entry, and its distinct underlying expressions are returned. Any other
+/// list (including an empty one) is returned unchanged as an owned copy.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::Plan` if a leading `GroupingSet` is followed
+/// by further expressions.
+pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+    match group_expr {
+        [Expr::GroupingSet(grouping_set)] => Ok(grouping_set.distinct_expr()),
+        [Expr::GroupingSet(_), ..] => Err(DataFusionError::Plan(
+            "Invalid group by expressions, GroupingSet must be the only expression"
+                .to_string(),
+        )),
+        _ => Ok(group_expr.to_vec()),
+    }
+}
+
 /// Recursively walk an expression tree, collecting the unique set of column names
 /// referenced in the expression
 struct ColumnNameVisitor<'a> {
diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs
index c9aee1e03..ae2cc4fce 100644
--- a/datafusion/optimizer/src/projection_push_down.rs
+++ b/datafusion/optimizer/src/projection_push_down.rs
@@ -24,6 +24,7 @@ use arrow::error::Result as ArrowResult;
 use datafusion_common::{
     Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema,
 };
+use datafusion_expr::utils::grouping_set_to_exprlist;
 use datafusion_expr::{
     logical_plan::{
         builder::{build_join_schema, LogicalPlanBuilder},
@@ -314,7 +315,10 @@ fn optimize_plan(
             // * remove any aggregate expression that is not required
             // * construct the new set of required columns
 
-            exprlist_to_columns(group_expr, &mut new_required_columns)?;
+            // Find distinct group by exprs in the case where we have a grouping set
+            let all_group_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr)?;
+
+            exprlist_to_columns(&all_group_expr, &mut new_required_columns)?;
 
             // Gather all columns needed for expressions in this Aggregate
             let mut new_aggr_expr = Vec::new();
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index c508b9772..80214f302 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -19,6 +19,7 @@
 
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::{DFSchema, Result};
+use datafusion_expr::utils::grouping_set_to_exprlist;
 use datafusion_expr::{
     col,
     logical_plan::{Aggregate, LogicalPlan, Projection},
@@ -62,9 +63,11 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
             schema,
             group_expr,
         }) => {
-            if is_single_distinct_agg(plan) {
+            if is_single_distinct_agg(plan) && !contains_grouping_set(group_expr) {
                 let mut group_fields_set = HashSet::new();
-                let mut all_group_args = group_expr.clone();
+                let base_group_expr = grouping_set_to_exprlist(group_expr)?;
+                let mut all_group_args: Vec<Expr> = group_expr.clone();
+
                 // remove distinct and collection args
                 let new_aggr_expr = aggr_expr
                     .iter()
@@ -87,7 +90,9 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                     })
                     .collect::<Vec<_>>();
 
-                let all_field = all_group_args
+                let all_group_expr = grouping_set_to_exprlist(&all_group_args)?;
+
+                let all_field = all_group_expr
                     .iter()
                     .map(|expr| expr.to_field(input.schema()).unwrap())
                     .collect::<Vec<_>>();
@@ -106,7 +111,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                 let grouped_agg = optimize_children(&grouped_agg);
                 let final_agg_schema = Arc::new(
                     DFSchema::new_with_metadata(
-                        group_expr
+                        base_group_expr
                             .iter()
                             .chain(new_aggr_expr.iter())
                             .map(|expr| expr.to_field(&grouped_schema).unwrap())
@@ -115,18 +120,12 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                     )
                     .unwrap(),
                 );
-                let final_agg = LogicalPlan::Aggregate(Aggregate {
-                    input: Arc::new(grouped_agg.unwrap()),
-                    group_expr: group_expr.clone(),
-                    aggr_expr: new_aggr_expr,
-                    schema: final_agg_schema.clone(),
-                });
 
                 // so the aggregates are displayed in the same way even after the rewrite
                 let mut alias_expr: Vec<Expr> = Vec::new();
-                final_agg
-                    .expressions()
+                base_group_expr
                     .iter()
+                    .chain(new_aggr_expr.iter())
                     .enumerate()
                     .for_each(|(i, field)| {
                         alias_expr.push(columnize_expr(
@@ -135,11 +134,18 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                         ));
                     });
 
+                let final_agg = LogicalPlan::Aggregate(Aggregate {
+                    input: Arc::new(grouped_agg.unwrap()),
+                    group_expr: group_expr.clone(),
+                    aggr_expr: new_aggr_expr,
+                    schema: final_agg_schema,
+                });
+
                 Ok(LogicalPlan::Projection(Projection {
                     expr: alias_expr,
                     input: Arc::new(final_agg),
                     schema: schema.clone(),
-                    alias: Option::None,
+                    alias: None,
                 }))
             } else {
                 optimize_children(plan)
@@ -185,6 +191,10 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> bool {
     }
 }
 
+/// Whether the GROUP BY list starts with an `Expr::GroupingSet`; only the
+/// first expression is inspected.
+fn contains_grouping_set(expr: &[Expr]) -> bool {
+    matches!(expr, [Expr::GroupingSet(_), ..])
+}
+
 impl OptimizerRule for SingleDistinctToGroupBy {
     fn optimize(
         &self,
@@ -202,6 +212,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
 mod tests {
     use super::*;
     use crate::test::*;
+    use datafusion_expr::expr::GroupingSet;
     use datafusion_expr::{
         col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
         AggregateFunction,
@@ -212,6 +223,7 @@ mod tests {
         let optimized_plan = rule
             .optimize(plan, &OptimizerConfig::new())
             .expect("failed to optimize plan");
+
         let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
         assert_eq!(formatted_plan, expected);
     }
@@ -250,6 +262,69 @@ mod tests {
         Ok(())
     }
 
+    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
+    #[test]
+    fn single_distinct_and_grouping_set() -> Result<()> {
+        let scan = test_table_scan()?;
+        let sets = vec![vec![col("a")], vec![col("b")]];
+        let group_by = Expr::GroupingSet(GroupingSet::GroupingSets(sets));
+
+        let plan = LogicalPlanBuilder::from(scan)
+            .aggregate(vec![group_by], vec![count_distinct(col("c"))])?
+            .build()?;
+
+        // The rule must leave the aggregate exactly as built.
+        let expected = "Aggregate: groupBy=[[GROUPING SETS ((#test.a), (#test.b))]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+                            \n  TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
+    #[test]
+    fn single_distinct_and_cube() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])?
+            .build()?;
+
+        // Should not be optimized: the rule must leave the CUBE aggregate as built.
+        let expected = "Aggregate: groupBy=[[CUBE (#test.a, #test.b)]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+                            \n  TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET
+    #[test]
+    fn single_distinct_and_rollup() -> Result<()> {
+        let scan = test_table_scan()?;
+        let rollup_keys = vec![col("a"), col("b")];
+        let group_by = Expr::GroupingSet(GroupingSet::Rollup(rollup_keys));
+
+        let plan = LogicalPlanBuilder::from(scan)
+            .aggregate(vec![group_by], vec![count_distinct(col("c"))])?
+            .build()?;
+
+        // The rule must leave the ROLLUP aggregate exactly as built.
+        let expected = "Aggregate: groupBy=[[ROLLUP (#test.a, #test.b)]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+                            \n  TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
     #[test]
     fn single_distinct_expr() -> Result<()> {
         let table_scan = test_table_scan()?;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 4522cd63e..dffc8ec2f 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -300,9 +300,33 @@ message LogicalExprNode {
     ScalarUDFExprNode scalar_udf_expr = 20;
 
     GetIndexedField get_indexed_field = 21;
+
+    GroupingSetNode grouping_set = 22;
+
+    CubeNode cube = 23;
+
+    RollupNode rollup = 24;
   }
 }
 
+// A list of logical expressions. Protobuf has no repeated-of-repeated
+// fields, so this wrapper is used to nest one list inside another.
+message LogicalExprList {
+  repeated LogicalExprNode expr = 1;
+}
+
+// GROUPING SETS: one `LogicalExprList` per grouping set, decoded into
+// `GroupingSet::GroupingSets`.
+message GroupingSetNode {
+  repeated LogicalExprList expr = 1;
+}
+
+// CUBE (expr, ...): decoded into `GroupingSet::Cube`.
+message CubeNode {
+  repeated LogicalExprNode expr = 1;
+}
+
+// ROLLUP (expr, ...): decoded into `GroupingSet::Rollup`.
+message RollupNode {
+  repeated LogicalExprNode expr = 1;
+}
+
+
+
 message GetIndexedField {
   LogicalExprNode expr = 1;
   ScalarValue key = 2;
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index c684b785e..279cb8e40 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -20,12 +20,16 @@ use crate::protobuf::plan_type::PlanTypeEnum::{
     FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan,
     OptimizedLogicalPlan, OptimizedPhysicalPlan,
 };
-use crate::protobuf::{OptimizedLogicalPlanType, OptimizedPhysicalPlanType};
+use crate::protobuf::{
+    CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType,
+    RollupNode,
+};
 use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode};
 use datafusion::logical_plan::FunctionRegistry;
 use datafusion_common::{
     Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue,
 };
+use datafusion_expr::expr::GroupingSet;
 use datafusion_expr::{
     abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr,
     coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp,
@@ -1290,6 +1294,34 @@ pub fn parse_expr(
                     .collect::<Result<Vec<_>, Error>>()?,
             })
         }
+
+        // GROUPING SETS: one expression list per grouping set.
+        ExprType::GroupingSet(GroupingSetNode { expr }) => {
+            Ok(Expr::GroupingSet(GroupingSet::GroupingSets(
+                expr.iter()
+                    .map(|expr_list| {
+                        expr_list
+                            .expr
+                            .iter()
+                            .map(|expr| parse_expr(expr, registry))
+                            .collect::<Result<Vec<_>, Error>>()
+                    })
+                    .collect::<Result<Vec<_>, Error>>()?,
+            )))
+        }
+        // CUBE and ROLLUP: a single flat list of grouping expressions.
+        ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube(
+            expr.iter()
+                .map(|expr| parse_expr(expr, registry))
+                .collect::<Result<Vec<_>, Error>>()?,
+        ))),
+        ExprType::Rollup(RollupNode { expr }) => {
+            Ok(Expr::GroupingSet(GroupingSet::Rollup(
+                expr.iter()
+                    .map(|expr| parse_expr(expr, registry))
+                    .collect::<Result<Vec<_>, Error>>()?,
+            )))
+        }
     }
 }
 
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 6fe1aac68..f08a00b49 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -62,6 +62,7 @@ mod roundtrip_tests {
     use datafusion::physical_plan::functions::make_scalar_function;
     use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext};
     use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
+    use datafusion_expr::expr::GroupingSet;
     use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
     use datafusion_expr::{
         col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, Expr,
@@ -1001,4 +1002,36 @@ mod roundtrip_tests {
 
         roundtrip_expr_test!(test_expr, ctx);
     }
+
+    /// `GROUPING SETS ((a), (b), (a, b))` must round-trip through protobuf.
+    #[test]
+    fn roundtrip_grouping_sets() {
+        let ctx = SessionContext::new();
+        let sets = vec![
+            vec![col("a")],
+            vec![col("b")],
+            vec![col("a"), col("b")],
+        ];
+        let grouping_sets = Expr::GroupingSet(GroupingSet::GroupingSets(sets));
+
+        roundtrip_expr_test!(grouping_sets, ctx);
+    }
+
+    /// `ROLLUP (a, b)` must round-trip through protobuf.
+    #[test]
+    fn roundtrip_rollup() {
+        let ctx = SessionContext::new();
+        let rollup = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")]));
+
+        roundtrip_expr_test!(rollup, ctx);
+    }
+
+    /// `CUBE (a, b)` must round-trip through protobuf.
+    #[test]
+    fn roundtrip_cube() {
+        let ctx = SessionContext::new();
+        let cube = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")]));
+
+        roundtrip_expr_test!(cube, ctx);
+    }
 }
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 8df8ff0dd..afe24ea89 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -25,12 +25,14 @@ use crate::protobuf::{
         FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan,
         OptimizedLogicalPlan, OptimizedPhysicalPlan,
     },
-    EmptyMessage, OptimizedLogicalPlanType, OptimizedPhysicalPlanType,
+    CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType,
+    OptimizedPhysicalPlanType, RollupNode,
 };
 use arrow::datatypes::{
     DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode,
 };
 use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue};
+use datafusion_expr::expr::GroupingSet;
 use datafusion_expr::{
     logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
     BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound,
@@ -718,9 +720,42 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                     },
                 ))),
             },
-            Expr::QualifiedWildcard { .. }
-            | Expr::TryCast { .. }
-            | Expr::GroupingSet(_) => unimplemented!(),
+
+            // CUBE / ROLLUP / GROUPING SETS: convert each grouping expression.
+            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
+                let expr = exprs
+                    .iter()
+                    .map(|expr| expr.try_into())
+                    .collect::<Result<Vec<Self>, Self::Error>>()?;
+                Self {
+                    expr_type: Some(ExprType::Cube(CubeNode { expr })),
+                }
+            }
+            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
+                let expr = exprs
+                    .iter()
+                    .map(|expr| expr.try_into())
+                    .collect::<Result<Vec<Self>, Self::Error>>()?;
+                Self {
+                    expr_type: Some(ExprType::Rollup(RollupNode { expr })),
+                }
+            }
+            Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => {
+                let expr = exprs
+                    .iter()
+                    .map(|expr_list| {
+                        let expr = expr_list
+                            .iter()
+                            .map(|expr| expr.try_into())
+                            .collect::<Result<Vec<Self>, Self::Error>>()?;
+                        Ok(LogicalExprList { expr })
+                    })
+                    .collect::<Result<Vec<LogicalExprList>, Self::Error>>()?;
+                Self {
+                    expr_type: Some(ExprType::GroupingSet(GroupingSetNode { expr })),
+                }
+            }
+            Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } => unimplemented!(),
         };
 
         Ok(expr_node)