Posted to commits@flink.apache.org by tw...@apache.org on 2017/01/16 15:41:43 UTC

flink git commit: [FLINK-5303] [table] Add CUBE/ROLLUP/GROUPING SETS operator in SQL

Repository: flink
Updated Branches:
  refs/heads/master 68228ffdd -> ef8cdfe59


[FLINK-5303] [table] Add CUBE/ROLLUP/GROUPING SETS operator in SQL

This closes #2976.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ef8cdfe5
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ef8cdfe5
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ef8cdfe5

Branch: refs/heads/master
Commit: ef8cdfe5930201f79c78f34cc9f462b4e88b3da1
Parents: 68228ff
Author: Aleksandr Chermenin <al...@epam.com>
Authored: Wed Dec 7 10:57:04 2016 +0300
Committer: twalthr <tw...@apache.org>
Committed: Mon Jan 16 16:40:01 2017 +0100

----------------------------------------------------------------------
 docs/dev/table_api.md                           |  50 ++++-
 .../table/expressions/ExpressionParser.scala    |   5 +-
 .../plan/nodes/dataset/DataSetAggregate.scala   |  11 +-
 .../nodes/datastream/DataStreamAggregate.scala  |   3 +-
 .../rules/dataSet/DataSetAggregateRule.scala    |  67 ++++--
 .../DataSetAggregateWithNullValuesRule.scala    |  21 +-
 .../datastream/DataStreamAggregateRule.scala    |   4 +-
 .../aggregate/AggregateMapFunction.scala        |   6 +-
 .../AggregateReduceCombineFunction.scala        |  15 +-
 .../AggregateReduceGroupFunction.scala          |  29 ++-
 .../table/runtime/aggregate/AggregateUtil.scala |  71 +++++-
 .../flink/table/validate/FunctionCatalog.scala  |   4 +
 .../api/java/batch/sql/GroupingSetsITCase.java  | 219 +++++++++++++++++++
 .../scala/batch/sql/AggregationsITCase.scala    |  20 +-
 .../api/scala/batch/sql/GroupingSetsTest.scala  | 206 +++++++++++++++++
 15 files changed, 654 insertions(+), 77 deletions(-)
----------------------------------------------------------------------

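For reference, a minimal sketch of the SQL syntax this change enables, written against the batch Table API that the new tests use. The environment setup, table name, and column names (MyTable, a, b, c) below are illustrative assumptions, not part of the commit:

    // sketch only: setup and names are hypothetical
    import org.apache.flink.api.scala._
    import org.apache.flink.table.api.TableEnvironment
    import org.apache.flink.table.api.scala._
    import org.apache.flink.types.Row

    val env = ExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)
    tEnv.registerDataSet("MyTable", env.fromElements((1, 1L, 1), (2, 2L, 1), (3, 2L, 2)), 'a, 'b, 'c)

    val cube = tEnv.sql(
      "SELECT b, c, AVG(a) AS a, GROUP_ID() AS g FROM MyTable GROUP BY CUBE (b, c)")
    val rollup = tEnv.sql(
      "SELECT b, c, AVG(a) AS a FROM MyTable GROUP BY ROLLUP (b, c)")
    val sets = tEnv.sql(
      "SELECT b, c, AVG(a) AS a FROM MyTable GROUP BY GROUPING SETS ((b, c), (b), ())")

    cube.toDataSet[Row].print()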

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 0fcd88d..acabfcf 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1334,7 +1334,7 @@ Among others, the following SQL features are not supported, yet:
 - Interval arithmetic is currently limited
 - Distinct aggregates (e.g., `COUNT(DISTINCT name)`)
 - Non-equi joins and Cartesian products
-- Grouping sets
+- Efficient grouping sets
 
 *Note: Tables are joined in the order in which they are specified in the `FROM` clause. In some cases the table order must be manually tweaked to resolve Cartesian products.*
 
@@ -1442,7 +1442,9 @@ groupItem:
   expression
   | '(' ')'
   | '(' expression [, expression ]* ')'
-
+  | CUBE '(' expression [, expression ]* ')'
+  | ROLLUP '(' expression [, expression ]* ')'
+  | GROUPING SETS '(' groupItem [, groupItem ]* ')'
 ```
 
 For a better definition of SQL queries within a Java String, Flink SQL uses a lexical policy similar to Java:
@@ -3762,6 +3764,50 @@ MIN(value)
 <table class="table table-bordered">
   <thead>
     <tr>
+      <th class="text-left" style="width: 40%">Grouping functions</th>
+      <th class="text-center">Description</th>
+    </tr>
+  </thead>
+
+  <tbody>
+    <tr>
+      <td>
+        {% highlight text %}
+GROUP_ID()
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns an integer that uniquely identifies the combination of grouping keys.</p>
+      </td>
+    </tr>
+
+    <tr>
+      <td>
+        {% highlight text %}
+GROUPING(expression)
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns 1 if <i>expression</i> is rolled up in the current row's grouping set, 0 otherwise.</p>
+      </td>
+    </tr>
+
+    <tr>
+      <td>
+        {% highlight text %}
+GROUPING_ID(expression [, expression]* )
+{% endhighlight %}
+      </td>
+      <td>
+        <p>Returns a bit vector of the given grouping expressions.</p>
+      </td>
+    </tr>
+  </tbody>
+</table>
+
+<table class="table table-bordered">
+  <thead>
+    <tr>
       <th class="text-left" style="width: 40%">Value access functions</th>
       <th class="text-center">Description</th>
     </tr>

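A hedged reading of the three grouping functions documented above, using the f1/f2 column names of the Java ITCase added by this commit:

    // For: SELECT f1, f2, GROUPING(f1), GROUPING(f2), GROUPING_ID(f1, f2)
    //      FROM MyTable GROUP BY GROUPING SETS (f1, f2)
    //
    // rows aggregated by f1:  GROUPING(f1) = 0, GROUPING(f2) = 1, GROUPING_ID(f1, f2) = 1
    // rows aggregated by f2:  GROUPING(f1) = 1, GROUPING(f2) = 0, GROUPING_ID(f1, f2) = 2
    //
    // i.e. GROUPING(x) flags a rolled-up column, and GROUPING_ID(...) packs those flags
    // into a bit vector with the first argument as the most significant bit.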
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
index 48dbce6..d85540a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
@@ -377,9 +377,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
 
   lazy val prefixed: PackratParser[Expression] =
     prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg |
-      prefixStart | prefixEnd |
-      prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract |
-      prefixFloor | prefixCeil | prefixGet | prefixFlattening |
+      prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs |
+      prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening |
       prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end
 
   // suffix/prefix composite

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index a5c42d9..6771536 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -46,12 +46,13 @@ class DataSetAggregate(
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
     rowRelDataType: RelDataType,
     inputType: RelDataType,
-    grouping: Array[Int])
+    grouping: Array[Int],
+    inGroupingSet: Boolean)
   extends SingleRel(cluster, traitSet, inputNode)
   with FlinkAggregate
   with DataSetRel {
 
-  override def deriveRowType() = rowRelDataType
+  override def deriveRowType(): RelDataType = rowRelDataType
 
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
     new DataSetAggregate(
@@ -61,7 +62,8 @@ class DataSetAggregate(
       namedAggregates,
       getRowType,
       inputType,
-      grouping)
+      grouping,
+      inGroupingSet)
   }
 
   override def toString: String = {
@@ -104,7 +106,8 @@ class DataSetAggregate(
       namedAggregates,
       inputType,
       rowRelDataType,
-      grouping)
+      grouping,
+      inGroupingSet)
 
     val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
       tableEnv,

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
index 9902486..6a3d4e3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
@@ -58,7 +58,7 @@ class DataStreamAggregate(
   with FlinkAggregate
   with DataStreamRel {
 
-  override def deriveRowType() = rowRelDataType
+  override def deriveRowType(): RelDataType = rowRelDataType
 
   override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
     new DataStreamAggregate(
@@ -242,6 +242,7 @@ class DataStreamAggregate(
         }
       }
     }
+
     // if the expected type is not a Row, inject a mapper to convert to the expected type
     expectedType match {
       case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
index d634a6c..d1f932e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
@@ -23,23 +23,21 @@ import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.convert.ConverterRule
 import org.apache.calcite.rel.logical.LogicalAggregate
 import org.apache.flink.table.api.TableException
-import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
-
+import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention, DataSetUnion}
 import scala.collection.JavaConversions._
 
 class DataSetAggregateRule
   extends ConverterRule(
-      classOf[LogicalAggregate],
-      Convention.NONE,
-      DataSetConvention.INSTANCE,
-      "DataSetAggregateRule")
-  {
+    classOf[LogicalAggregate],
+    Convention.NONE,
+    DataSetConvention.INSTANCE,
+    "DataSetAggregateRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
 
-    //for non grouped agg sets should attach null row to source data
-    //need apply DataSetAggregateWithNullValuesRule
+    // for non-grouped aggregates a null row is attached to the source data,
+    // which is handled by DataSetAggregateWithNullValuesRule instead
     if (agg.getGroupSet.isEmpty) {
       return false
     }
@@ -50,13 +48,7 @@ class DataSetAggregateRule
       throw TableException("DISTINCT aggregates are currently not supported.")
     }
 
-    // check if we have grouping sets
-    val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet
-    if (groupSets || agg.indicator) {
-      throw TableException("GROUPING SETS are currently not supported.")
-    }
-
-    !distinctAggs && !groupSets && !agg.indicator
+    !distinctAggs
   }
 
   override def convert(rel: RelNode): RelNode = {
@@ -64,16 +56,43 @@ class DataSetAggregateRule
     val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
     val convInput: RelNode = RelOptRule.convert(agg.getInput, DataSetConvention.INSTANCE)
 
-    new DataSetAggregate(
-      rel.getCluster,
-      traitSet,
-      convInput,
-      agg.getNamedAggCalls,
-      rel.getRowType,
-      agg.getInput.getRowType,
-      agg.getGroupSet.toArray)
+    if (agg.indicator) {
+        agg.groupSets.map(set =>
+          new DataSetAggregate(
+            rel.getCluster,
+            traitSet,
+            convInput,
+            agg.getNamedAggCalls,
+            rel.getRowType,
+            agg.getInput.getRowType,
+            set.toArray,
+            inGroupingSet = true
+          ).asInstanceOf[RelNode]
+        ).reduce(
+          (rel1, rel2) => {
+            new DataSetUnion(
+              rel.getCluster,
+              traitSet,
+              rel1,
+              rel2,
+              rel.getRowType
+            )
+          }
+        )
+    } else {
+      new DataSetAggregate(
+        rel.getCluster,
+        traitSet,
+        convInput,
+        agg.getNamedAggCalls,
+        rel.getRowType,
+        agg.getInput.getRowType,
+        agg.getGroupSet.toArray,
+        inGroupingSet = false
+      )
     }
   }
+}
 
 object DataSetAggregateRule {
   val INSTANCE: RelOptRule = new DataSetAggregateRule

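A rough sketch of the plan shape the rewritten DataSetAggregateRule now produces for an aggregate with grouping-set indicators (the exact plans are asserted in GroupingSetsTest further below):

    // GROUP BY CUBE (b, c)  is equivalent to  GROUP BY GROUPING SETS ((b, c), (b), (c), ())
    // and is translated into one DataSetAggregate per grouping set, chained by DataSetUnion:
    //
    //   DataSetUnion(
    //     DataSetUnion(
    //       DataSetUnion(
    //         DataSetAggregate(groupBy = b, c),
    //         DataSetAggregate(groupBy = b)),
    //       DataSetAggregate(groupBy = c)),
    //     DataSetAggregate())          // the empty grouping set (grand total)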
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
index b708af4..e8084fa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
@@ -29,22 +29,21 @@ import org.apache.flink.table.api.TableException
 import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
 
 /**
-  * Rule for insert [[org.apache.flink.types.Row]] with null records into a [[DataSetAggregate]]
-  * Rule apply for non grouped aggregate query
+  * Rule that inserts a [[org.apache.flink.types.Row]] with null records into a [[DataSetAggregate]].
+  * This rule applies to non-grouped aggregate queries.
   */
 class DataSetAggregateWithNullValuesRule
   extends ConverterRule(
     classOf[LogicalAggregate],
     Convention.NONE,
     DataSetConvention.INSTANCE,
-    "DataSetAggregateWithNullValuesRule")
-{
+    "DataSetAggregateWithNullValuesRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
 
-    //for grouped agg sets shouldn't attach of null row
-    //need apply other rules. e.g. [[DataSetAggregateRule]]
+    // grouped aggregates should not get a null row attached;
+    // other rules apply here, e.g. DataSetAggregateRule
     if (!agg.getGroupSet.isEmpty) {
       return false
     }
@@ -55,12 +54,7 @@ class DataSetAggregateWithNullValuesRule
       throw TableException("DISTINCT aggregates are currently not supported.")
     }
 
-    // check if we have grouping sets
-    val groupSets = agg.getGroupSets.size() == 0 || agg.getGroupSets.get(0) != agg.getGroupSet
-    if (groupSets || agg.indicator) {
-      throw TableException("GROUPING SETS are currently not supported.")
-    }
-    !distinctAggs && !groupSets && !agg.indicator
+    !distinctAggs
   }
 
   override def convert(rel: RelNode): RelNode = {
@@ -87,7 +81,8 @@ class DataSetAggregateWithNullValuesRule
       agg.getNamedAggCalls,
       rel.getRowType,
       agg.getInput.getRowType,
-      agg.getGroupSet.toArray
+      agg.getGroupSet.toArray,
+      inGroupingSet = false
     )
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
index bf8a18e..09f05d7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
@@ -33,8 +33,7 @@ class DataStreamAggregateRule
       classOf[LogicalWindowAggregate],
       Convention.NONE,
       DataStreamConvention.INSTANCE,
-      "DataStreamAggregateRule")
-  {
+      "DataStreamAggregateRule") {
 
   override def matches(call: RelOptRuleCall): Boolean = {
     val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate]
@@ -75,4 +74,3 @@ class DataStreamAggregateRule
 object DataStreamAggregateRule {
   val INSTANCE: RelOptRule = new DataStreamAggregateRule
 }
-

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
index 21a96e0..0033ff7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
@@ -37,7 +37,7 @@ class AggregateMapFunction[IN, OUT](
   override def open(config: Configuration) {
     Preconditions.checkNotNull(aggregates)
     Preconditions.checkNotNull(aggFields)
-    Preconditions.checkArgument(aggregates.size == aggFields.size)
+    Preconditions.checkArgument(aggregates.length == aggFields.length)
     val partialRowLength = groupingKeys.length +
         aggregates.map(_.intermediateDataType.length).sum
     output = new Row(partialRowLength)
@@ -46,11 +46,11 @@ class AggregateMapFunction[IN, OUT](
   override def map(value: IN): OUT = {
     
     val input = value.asInstanceOf[Row]
-    for (i <- 0 until aggregates.length) {
+    for (i <- aggregates.indices) {
       val fieldValue = input.getField(aggFields(i))
       aggregates(i).prepare(fieldValue, output)
     }
-    for (i <- 0 until groupingKeys.length) {
+    for (i <- groupingKeys.indices) {
       output.setField(i, input.getField(groupingKeys(i)))
     }
     output.asInstanceOf[OUT]

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
index 31b85cd..5237ecf 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
@@ -25,28 +25,31 @@ import org.apache.flink.types.Row
 
 import scala.collection.JavaConversions._
 
-
 /**
  * It wraps the aggregate logic inside of
  * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
  * [[org.apache.flink.api.java.operators.GroupCombineOperator]]
  *
- * @param aggregates   The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- *                         and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- *                         index in output Row.
+ * @param aggregates          The aggregate functions.
+ * @param groupKeysMapping    The index mapping of group keys between intermediate aggregate Row
+ *                            and output Row.
+ * @param aggregateMapping    The index mapping between aggregate function list and aggregated value
+ *                            index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ *                            Row and output Row.
  */
 class AggregateReduceCombineFunction(
     private val aggregates: Array[Aggregate[_ <: Any]],
     private val groupKeysMapping: Array[(Int, Int)],
     private val aggregateMapping: Array[(Int, Int)],
+    private val groupingSetsMapping: Array[(Int, Int)],
     private val intermediateRowArity: Int,
     private val finalRowArity: Int)
   extends AggregateReduceGroupFunction(
     aggregates,
     groupKeysMapping,
     aggregateMapping,
+    groupingSetsMapping,
     intermediateRowArity,
     finalRowArity)
   with CombineFunction[Row, Row] {

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
index c1efebb..c147629 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
@@ -20,38 +20,45 @@ package org.apache.flink.table.runtime.aggregate
 import java.lang.Iterable
 
 import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.types.Row
 import org.apache.flink.configuration.Configuration
+import org.apache.flink.types.Row
 import org.apache.flink.util.{Collector, Preconditions}
 
 import scala.collection.JavaConversions._
 
 /**
- * It wraps the aggregate logic inside of 
+ * It wraps the aggregate logic inside of
  * [[org.apache.flink.api.java.operators.GroupReduceOperator]].
  *
- * @param aggregates   The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row 
- *                         and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- *                         index in output Row.
+ * @param aggregates          The aggregate functions.
+ * @param groupKeysMapping    The index mapping of group keys between intermediate aggregate Row
+ *                            and output Row.
+ * @param aggregateMapping    The index mapping between aggregate function list and aggregated value
+ *                            index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ *                            Row and output Row.
  */
 class AggregateReduceGroupFunction(
     private val aggregates: Array[Aggregate[_ <: Any]],
     private val groupKeysMapping: Array[(Int, Int)],
     private val aggregateMapping: Array[(Int, Int)],
+    private val groupingSetsMapping: Array[(Int, Int)],
     private val intermediateRowArity: Int,
     private val finalRowArity: Int)
   extends RichGroupReduceFunction[Row, Row] {
 
   protected var aggregateBuffer: Row = _
   private var output: Row = _
+  private var intermediateGroupKeys: Option[Array[Int]] = None
 
   override def open(config: Configuration) {
     Preconditions.checkNotNull(aggregates)
     Preconditions.checkNotNull(groupKeysMapping)
     aggregateBuffer = new Row(intermediateRowArity)
     output = new Row(finalRowArity)
+    if (!groupingSetsMapping.isEmpty) {
+      intermediateGroupKeys = Some(groupKeysMapping.map(_._1))
+    }
   }
 
   /**
@@ -87,6 +94,14 @@ class AggregateReduceGroupFunction(
         output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
     }
 
+    // Evaluate additional values of grouping sets
+    if (intermediateGroupKeys.isDefined) {
+      groupingSetsMapping.foreach {
+        case (inputIndex, outputIndex) =>
+          output.setField(outputIndex, !intermediateGroupKeys.get.contains(inputIndex))
+      }
+    }
+
     out.collect(output)
   }
 }

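A hedged worked example of the indicator evaluation added above: for the grouping set (f1) of GROUPING SETS (f1, f2), the DataSetAggregate for that set groups only on f1, so the reduce function emits

    // i$f1 -> false   (f1 is a key of the current grouping set)
    // i$f2 -> true    (f2 is rolled up; its value column is null for these rows)

The downstream DataSetCalc later turns these boolean indicators into the GROUPING()/GROUP_ID() values, as visible in the GroupingSetsTest plans below.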
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 1e48288..6b7e03a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -241,10 +241,12 @@ object AggregateUtil {
     *
     */
   private[flink] def createAggregateGroupReduceFunction(
-    namedAggregates: Seq[CalcitePair[AggregateCall, String]],
-    inputType: RelDataType,
-    outputType: RelDataType,
-    groupings: Array[Int]): RichGroupReduceFunction[Row, Row] = {
+      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+      inputType: RelDataType,
+      outputType: RelDataType,
+      groupings: Array[Int],
+      inGroupingSet: Boolean)
+    : RichGroupReduceFunction[Row, Row] = {
 
     val aggregates = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
@@ -258,6 +260,12 @@ object AggregateUtil {
         outputType,
         groupings)
 
+    val groupingSetsMapping: Array[(Int, Int)] = if (inGroupingSet) {
+      getGroupingSetsIndicatorMapping(inputType, outputType)
+    } else {
+      Array()
+    }
+
     val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial)
 
     val intermediateRowArity = groupings.length +
@@ -269,6 +277,7 @@ object AggregateUtil {
           aggregates,
           groupingOffsetMapping,
           aggOffsetMapping,
+          groupingSetsMapping,
           intermediateRowArity,
           outputType.getFieldCount)
       }
@@ -277,6 +286,7 @@ object AggregateUtil {
           aggregates,
           groupingOffsetMapping,
           aggOffsetMapping,
+          groupingSetsMapping,
           intermediateRowArity,
           outputType.getFieldCount)
       }
@@ -329,7 +339,8 @@ object AggregateUtil {
         namedAggregates,
         inputType,
         outputType,
-        groupings)
+        groupings,
+        inGroupingSet = false)
 
     if (isTimeWindow(window)) {
       val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
@@ -358,7 +369,8 @@ object AggregateUtil {
         namedAggregates,
         inputType,
         outputType,
-        groupings)
+        groupings,
+        inGroupingSet = false)
 
     if (isTimeWindow(window)) {
       val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
@@ -371,7 +383,7 @@ object AggregateUtil {
 
   /**
     * Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned
-    * window aggreagtes.
+    * window aggregates.
     */
   private[flink] def createAllWindowIncrementalAggregationFunction(
     window: LogicalWindow,
@@ -495,6 +507,51 @@ object AggregateUtil {
     (groupingOffsetMapping, aggOffsetMapping)
   }
 
+  /**
+    * Determines the mapping of grouping keys to boolean indicators that describe the
+    * current grouping set.
+    *
+    * E.g.: Given we group on f1 and f2 of the input type, the output type contains two
+    * boolean indicator fields i$f1 and i$f2.
+    */
+  private def getGroupingSetsIndicatorMapping(
+      inputType: RelDataType,
+      outputType: RelDataType)
+    : Array[(Int, Int)] = {
+
+    val inputFields = inputType.getFieldList.map(_.getName)
+
+    // map from field -> i$field or field -> i$field_0
+    val groupingFields = inputFields.map(inputFieldName => {
+        val base = "i$" + inputFieldName
+        var name = base
+        var i = 0
+        while (inputFields.contains(name)) {
+          name = base + "_" + i // if i$XXX is already a field it will be suffixed by _NUMBER
+          i = i + 1
+        }
+        inputFieldName -> name
+      }).toMap
+
+    val outputFields = outputType.getFieldList
+
+    var mappingsBuffer = ArrayBuffer[(Int, Int)]()
+    for (i <- outputFields.indices) {
+      for (j <- outputFields.indices) {
+        val possibleKey = outputFields(i).getName
+        val possibleIndicator1 = outputFields(j).getName
+        // get indicator for output field
+        val possibleIndicator2 = groupingFields.getOrElse(possibleKey, null)
+
+        // check if indicator names match
+        if (possibleIndicator1 == possibleIndicator2) {
+          mappingsBuffer += ((i, j))
+        }
+      }
+    }
+    mappingsBuffer.toArray
+  }
+
   private def isTimeWindow(window: LogicalWindow) = {
     window match {
       case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType)

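The indicator-column naming used by getGroupingSetsIndicatorMapping above can be restated as a small standalone sketch (an illustration of the loop in this hunk, not an API added by the commit):

    // derive the indicator name for `field`, avoiding clashes with existing input columns
    def indicatorName(inputFields: Seq[String], field: String): String = {
      val base = "i$" + field
      var name = base
      var i = 0
      while (inputFields.contains(name)) {
        name = base + "_" + i   // i$field_0, i$field_1, ... until no clash remains
        i += 1
      }
      name
    }

    // indicatorName(Seq("b", "c"), "b")    == "i$b"
    // indicatorName(Seq("b", "i$b"), "b")  == "i$b_0"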
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index f92b3a1..c00f8bb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -257,6 +257,10 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
     SqlStdOperatorTable.NOT,
     SqlStdOperatorTable.UNARY_MINUS,
     SqlStdOperatorTable.UNARY_PLUS,
+    // GROUPING FUNCTIONS
+    SqlStdOperatorTable.GROUP_ID,
+    SqlStdOperatorTable.GROUPING,
+    SqlStdOperatorTable.GROUPING_ID,
     // AGGREGATE OPERATORS
     SqlStdOperatorTable.SUM,
     SqlStdOperatorTable.COUNT,

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java
new file mode 100644
index 0000000..f7111f7
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.java.batch.sql;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.operators.MapOperator;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.java.BatchTableEnvironment;
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
+import org.apache.flink.test.util.TestBaseUtils;
+import org.apache.flink.types.Row;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Comparator;
+import java.util.List;
+
+@RunWith(Parameterized.class)
+public class GroupingSetsITCase extends TableProgramsTestBase {
+
+	private final static String TABLE_NAME = "MyTable";
+	private final static String TABLE_WITH_NULLS_NAME = "MyTableWithNulls";
+	private BatchTableEnvironment tableEnv;
+
+	public GroupingSetsITCase(TestExecutionMode mode, TableConfigMode tableConfigMode) {
+		super(mode, tableConfigMode);
+	}
+
+	@Before
+	public void setupTables() {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		tableEnv = TableEnvironment.getTableEnvironment(env, new TableConfig());
+
+		DataSet<Tuple3<Integer, Long, String>> dataSet = CollectionDataSets.get3TupleDataSet(env);
+		tableEnv.registerDataSet(TABLE_NAME, dataSet);
+
+		MapOperator<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> dataSetWithNulls =
+			dataSet.map(new MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>>() {
+
+				@Override
+				public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> value) throws Exception {
+					if (value.f2.toLowerCase().contains("world")) {
+						value.f2 = null;
+					}
+					return value;
+				}
+			});
+		tableEnv.registerDataSet(TABLE_WITH_NULLS_NAME, dataSetWithNulls);
+	}
+
+	@Test
+	public void testGroupingSets() throws Exception {
+		String query =
+			"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+				" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+				" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+				" GROUPING_ID(f1, f2) as gid " +
+				" FROM " + TABLE_NAME +
+				" GROUP BY GROUPING SETS (f1, f2)";
+
+		String expected =
+			"1,null,1,1,0,1,0,1,1\n" +
+			"6,null,18,1,0,1,0,1,1\n" +
+			"2,null,2,1,0,1,0,1,1\n" +
+			"4,null,8,1,0,1,0,1,1\n" +
+			"5,null,13,1,0,1,0,1,1\n" +
+			"3,null,5,1,0,1,0,1,1\n" +
+			"null,Comment#11,17,2,1,0,1,0,2\n" +
+			"null,Comment#8,14,2,1,0,1,0,2\n" +
+			"null,Comment#2,8,2,1,0,1,0,2\n" +
+			"null,Comment#1,7,2,1,0,1,0,2\n" +
+			"null,Comment#14,20,2,1,0,1,0,2\n" +
+			"null,Comment#7,13,2,1,0,1,0,2\n" +
+			"null,Comment#6,12,2,1,0,1,0,2\n" +
+			"null,Comment#3,9,2,1,0,1,0,2\n" +
+			"null,Comment#12,18,2,1,0,1,0,2\n" +
+			"null,Comment#5,11,2,1,0,1,0,2\n" +
+			"null,Comment#15,21,2,1,0,1,0,2\n" +
+			"null,Comment#4,10,2,1,0,1,0,2\n" +
+			"null,Hi,1,2,1,0,1,0,2\n" +
+			"null,Comment#10,16,2,1,0,1,0,2\n" +
+			"null,Hello world,3,2,1,0,1,0,2\n" +
+			"null,I am fine.,5,2,1,0,1,0,2\n" +
+			"null,Hello world, how are you?,4,2,1,0,1,0,2\n" +
+			"null,Comment#9,15,2,1,0,1,0,2\n" +
+			"null,Comment#13,19,2,1,0,1,0,2\n" +
+			"null,Luke Skywalker,6,2,1,0,1,0,2\n" +
+			"null,Hello,2,2,1,0,1,0,2";
+
+		checkSql(query, expected);
+	}
+
+	@Test
+	public void testGroupingSetsWithNulls() throws Exception {
+		String query =
+			"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g FROM " + TABLE_WITH_NULLS_NAME +
+				" GROUP BY GROUPING SETS (f1, f2)";
+
+		String expected =
+			"6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" +
+				"null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" +
+				"null,null,3,2\nnull,Hello,2,2\nnull,Comment#9,15,2\nnull,Comment#8,14,2\n" +
+				"null,Comment#7,13,2\nnull,Comment#6,12,2\nnull,Comment#5,11,2\n" +
+				"null,Comment#4,10,2\nnull,Comment#3,9,2\nnull,Comment#2,8,2\n" +
+				"null,Comment#15,21,2\nnull,Comment#14,20,2\nnull,Comment#13,19,2\n" +
+				"null,Comment#12,18,2\nnull,Comment#11,17,2\nnull,Comment#10,16,2\n" +
+				"null,Comment#1,7,2";
+
+		checkSql(query, expected);
+	}
+
+	@Test
+	public void testCubeAsGroupingSets() throws Exception {
+		String cubeQuery =
+			"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+				" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+				" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+				" GROUPING_ID(f1, f2) as gid " +
+				" FROM " + TABLE_NAME + " GROUP BY CUBE (f1, f2)";
+
+		String groupingSetsQuery =
+			"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+				" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+				" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+				" GROUPING_ID(f1, f2) as gid " +
+				" FROM " + TABLE_NAME +
+				" GROUP BY GROUPING SETS ((f1, f2), (f1), (f2), ())";
+
+		compareSql(cubeQuery, groupingSetsQuery);
+	}
+
+	@Test
+	public void testRollupAsGroupingSets() throws Exception {
+		String rollupQuery =
+			"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+				" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+				" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+				" GROUPING_ID(f1, f2) as gid " +
+				" FROM " + TABLE_NAME + " GROUP BY ROLLUP (f1, f2)";
+
+		String groupingSetsQuery =
+			"SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+				" GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+				" GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+				" GROUPING_ID(f1, f2) as gid " +
+				" FROM " + TABLE_NAME +
+				" GROUP BY GROUPING SETS ((f1, f2), (f1), ())";
+
+		compareSql(rollupQuery, groupingSetsQuery);
+	}
+
+	/**
+	 * Execute SQL query and check results.
+	 *
+	 * @param query    SQL query.
+	 * @param expected Expected result.
+	 */
+	private void checkSql(String query, String expected) throws Exception {
+		Table resultTable = tableEnv.sql(query);
+		DataSet<Row> resultDataSet = tableEnv.toDataSet(resultTable, Row.class);
+		List<Row> results = resultDataSet.collect();
+		TestBaseUtils.compareResultAsText(results, expected);
+	}
+
+	private void compareSql(String query1, String query2) throws Exception {
+
+		// Function to map row to string
+		MapFunction<Row, String> mapFunction = new MapFunction<Row, String>() {
+
+			@Override
+			public String map(Row value) throws Exception {
+				return value == null ? "null" : value.toString();
+			}
+		};
+
+		// Execute first query and store results
+		Table resultTable1 = tableEnv.sql(query1);
+		DataSet<Row> resultDataSet1 = tableEnv.toDataSet(resultTable1, Row.class);
+		List<String> results1 = resultDataSet1.map(mapFunction).collect();
+
+		// Execute second query and store results
+		Table resultTable2 = tableEnv.sql(query2);
+		DataSet<Row> resultDataSet2 = tableEnv.toDataSet(resultTable2, Row.class);
+		List<String> results2 = resultDataSet2.map(mapFunction).collect();
+
+		// Compare results
+		TestBaseUtils.compareResultCollections(results1, results2, new Comparator<String>() {
+
+			@Override
+			public int compare(String o1, String o2) {
+				return o2 == null ? o1 == null ? 0 : 1 : o1.compareTo(o2);
+			}
+		});
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
index 4f55bee..d6f2b7b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
@@ -245,19 +245,31 @@ class AggregationsITCase(
     tEnv.sql(sqlQuery).toDataSet[Row]
   }
 
-  @Test(expected = classOf[TableException])
+  @Test
   def testGroupingSetAggregate(): Unit = {
 
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env, config)
 
-    val sqlQuery = "SELECT _2, _3, avg(_1) as a FROM MyTable GROUP BY GROUPING SETS (_2, _3)"
+    val sqlQuery =
+      "SELECT _2, _3, avg(_1) as a, GROUP_ID() as g FROM MyTable GROUP BY GROUPING SETS (_2, _3)"
 
     val ds = CollectionDataSets.get3TupleDataSet(env)
     tEnv.registerDataSet("MyTable", ds)
 
-    // must fail. grouping sets are not supported
-    tEnv.sql(sqlQuery).toDataSet[Row]
+    val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
+
+    val expected =
+      "6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" +
+        "null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" +
+        "null,Hello world, how are you?,4,2\nnull,Hello world,3,2\nnull,Hello,2,2\n" +
+        "null,Comment#9,15,2\nnull,Comment#8,14,2\nnull,Comment#7,13,2\n" +
+        "null,Comment#6,12,2\nnull,Comment#5,11,2\nnull,Comment#4,10,2\n" +
+        "null,Comment#3,9,2\nnull,Comment#2,8,2\nnull,Comment#15,21,2\n" +
+        "null,Comment#14,20,2\nnull,Comment#13,19,2\nnull,Comment#12,18,2\n" +
+        "null,Comment#11,17,2\nnull,Comment#10,16,2\nnull,Comment#1,7,2"
+
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala
new file mode 100644
index 0000000..c12e5a6
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.scala.batch.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+class GroupingSetsTest extends TableTestBase {
+
+  @Test
+  def testGroupingSets(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g FROM MyTable " +
+      "GROUP BY GROUPING SETS (b, c)"
+
+    val aggregate = unaryNode(
+      "DataSetCalc",
+      binaryNode(
+        "DataSetUnion",
+        unaryNode(
+          "DataSetAggregate",
+          batchTableNode(0),
+          term("groupBy", "b"),
+          term("select", "b", "AVG(a) AS c")
+        ),
+        unaryNode(
+          "DataSetAggregate",
+          batchTableNode(0),
+          term("groupBy", "c"),
+          term("select", "c AS b", "AVG(a) AS c")
+        ),
+        term("union", "b", "c", "i$b", "i$c", "a")
+      ),
+      term("select",
+        "CASE(i$b, null, b) AS b",
+        "CASE(i$c, null, c) AS c",
+        "a",
+        "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g") // GROUP_ID()
+    )
+
+    util.verifySql(sqlQuery, aggregate)
+  }
+
+  @Test
+  def testCube(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g, " +
+      "GROUPING(b) as gb, GROUPING(c) as gc, " +
+      "GROUPING_ID(b) as gib, GROUPING_ID(c) as gic, " +
+      "GROUPING_ID(b, c) as gid " +
+      "FROM MyTable " +
+      "GROUP BY CUBE (b, c)"
+
+    val group1 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b, c"),
+      term("select", "b", "c",
+           "AVG(a) AS i$b")
+    )
+
+    val group2 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b"),
+      term("select", "b",
+           "AVG(a) AS c")
+    )
+
+    val group3 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "c"),
+      term("select", "c AS b",
+           "AVG(a) AS c")
+    )
+
+    val group4 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("select",
+           "AVG(a) AS b")
+    )
+
+    val union1 = binaryNode(
+      "DataSetUnion",
+      group1, group2,
+      term("union", "b", "c", "i$b", "i$c", "a")
+    )
+
+    val union2 = binaryNode(
+      "DataSetUnion",
+      union1, group3,
+      term("union", "b", "c", "i$b", "i$c", "a")
+    )
+
+    val union3 = binaryNode(
+      "DataSetUnion",
+      union2, group4,
+      term("union", "b", "c", "i$b", "i$c", "a")
+    )
+
+    val aggregate = unaryNode(
+      "DataSetCalc",
+      union3,
+      term("select",
+           "CASE(i$b, null, b) AS b",
+           "CASE(i$c, null, c) AS c",
+           "a",
+           "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g", // GROUP_ID()
+           "CASE(i$b, 1, 0) AS gb", // GROUPING(b)
+           "CASE(i$c, 1, 0) AS gc", // GROUPING(c)
+           "CASE(i$b, 1, 0) AS gib", // GROUPING_ID(b)
+           "CASE(i$c, 1, 0) AS gic", // GROUPING_ID(c)
+           "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS gid") // GROUPING_ID(b, c)
+    )
+
+    util.verifySql(sqlQuery, aggregate)
+  }
+
+  @Test
+  def testRollup(): Unit = {
+    val util = batchTestUtil()
+    util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT b, c, avg(a) as a, GROUP_ID() as g, " +
+                   "GROUPING(b) as gb, GROUPING(c) as gc, " +
+                   "GROUPING_ID(b) as gib, GROUPING_ID(c) as gic, " +
+                   "GROUPING_ID(b, c) as gid " + " FROM MyTable " +
+                   "GROUP BY ROLLUP (b, c)"
+
+    val group1 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b, c"),
+      term("select", "b", "c",
+           "AVG(a) AS i$b")
+    )
+
+    val group2 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b"),
+      term("select", "b",
+           "AVG(a) AS c")
+    )
+
+    val group3 = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("select",
+           "AVG(a) AS b")
+    )
+
+    val union1 = binaryNode(
+      "DataSetUnion",
+      group1, group2,
+      term("union", "b", "c", "i$b", "i$c", "a")
+    )
+
+    val union2 = binaryNode(
+      "DataSetUnion",
+      union1, group3,
+      term("union", "b", "c", "i$b", "i$c", "a")
+    )
+
+    val aggregate = unaryNode(
+      "DataSetCalc",
+      union2,
+      term("select",
+           "CASE(i$b, null, b) AS b",
+           "CASE(i$c, null, c) AS c",
+           "a",
+           "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g", // GROUP_ID()
+           "CASE(i$b, 1, 0) AS gb", // GROUPING(b)
+           "CASE(i$c, 1, 0) AS gc", // GROUPING(c)
+           "CASE(i$b, 1, 0) AS gib", // GROUPING_ID(b)
+           "CASE(i$c, 1, 0) AS gic", // GROUPING_ID(c)
+           "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS gid") // GROUPING_ID(b, c)
+    )
+
+    util.verifySql(sqlQuery, aggregate)
+  }
+}