You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tw...@apache.org on 2017/01/16 15:41:43 UTC
flink git commit: [FLINK-5303] [table] Add CUBE/ROLLUP/GROUPING SETS
operator in SQL
Repository: flink
Updated Branches:
refs/heads/master 68228ffdd -> ef8cdfe59
[FLINK-5303] [table] Add CUBE/ROLLUP/GROUPING SETS operator in SQL
This closes #2976.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ef8cdfe5
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ef8cdfe5
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ef8cdfe5
Branch: refs/heads/master
Commit: ef8cdfe5930201f79c78f34cc9f462b4e88b3da1
Parents: 68228ff
Author: Aleksandr Chermenin <al...@epam.com>
Authored: Wed Dec 7 10:57:04 2016 +0300
Committer: twalthr <tw...@apache.org>
Committed: Mon Jan 16 16:40:01 2017 +0100
----------------------------------------------------------------------
docs/dev/table_api.md | 50 ++++-
.../table/expressions/ExpressionParser.scala | 5 +-
.../plan/nodes/dataset/DataSetAggregate.scala | 11 +-
.../nodes/datastream/DataStreamAggregate.scala | 3 +-
.../rules/dataSet/DataSetAggregateRule.scala | 67 ++++--
.../DataSetAggregateWithNullValuesRule.scala | 21 +-
.../datastream/DataStreamAggregateRule.scala | 4 +-
.../aggregate/AggregateMapFunction.scala | 6 +-
.../AggregateReduceCombineFunction.scala | 15 +-
.../AggregateReduceGroupFunction.scala | 29 ++-
.../table/runtime/aggregate/AggregateUtil.scala | 71 +++++-
.../flink/table/validate/FunctionCatalog.scala | 4 +
.../api/java/batch/sql/GroupingSetsITCase.java | 219 +++++++++++++++++++
.../scala/batch/sql/AggregationsITCase.scala | 20 +-
.../api/scala/batch/sql/GroupingSetsTest.scala | 206 +++++++++++++++++
15 files changed, 654 insertions(+), 77 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 0fcd88d..acabfcf 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -1334,7 +1334,7 @@ Among others, the following SQL features are not supported, yet:
- Interval arithmetic is currenly limited
- Distinct aggregates (e.g., `COUNT(DISTINCT name)`)
- Non-equi joins and Cartesian products
-- Grouping sets
+- Efficient grouping sets
*Note: Tables are joined in the order in which they are specified in the `FROM` clause. In some cases the table order must be manually tweaked to resolve Cartesian products.*
@@ -1442,7 +1442,9 @@ groupItem:
expression
| '(' ')'
| '(' expression [, expression ]* ')'
-
+ | CUBE '(' expression [, expression ]* ')'
+ | ROLLUP '(' expression [, expression ]* ')'
+ | GROUPING SETS '(' groupItem [, groupItem ]* ')'
```
For a better definition of SQL queries within a Java String, Flink SQL uses a lexical policy similar to Java:
@@ -3762,6 +3764,50 @@ MIN(value)
<table class="table table-bordered">
<thead>
<tr>
+ <th class="text-left" style="width: 40%">Grouping functions</th>
+ <th class="text-center">Description</th>
+ </tr>
+ </thead>
+
+ <tbody>
+ <tr>
+ <td>
+ {% highlight text %}
+GROUP_ID()
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns an integer that uniquely identifies the combination of grouping keys.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight text %}
+GROUPING(expression)
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns 1 if <i>expression</i> is rolled up in the current row’s grouping set, 0 otherwise.</p>
+ </td>
+ </tr>
+
+ <tr>
+ <td>
+ {% highlight text %}
+GROUPING_ID(expression [, expression]* )
+{% endhighlight %}
+ </td>
+ <td>
+ <p>Returns a bit vector of the given grouping expressions.</p>
+ </td>
+ </tr>
+ </tbody>
+</table>
+
+<table class="table table-bordered">
+ <thead>
+ <tr>
<th class="text-left" style="width: 40%">Value access functions</th>
<th class="text-center">Description</th>
</tr>
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
index 48dbce6..d85540a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ExpressionParser.scala
@@ -377,9 +377,8 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
lazy val prefixed: PackratParser[Expression] =
prefixArray | prefixSum | prefixMin | prefixMax | prefixCount | prefixAvg |
- prefixStart | prefixEnd |
- prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs | prefixIf | prefixExtract |
- prefixFloor | prefixCeil | prefixGet | prefixFlattening |
+ prefixStart | prefixEnd | prefixCast | prefixAs | prefixTrim | prefixTrimWithoutArgs |
+ prefixIf | prefixExtract | prefixFloor | prefixCeil | prefixGet | prefixFlattening |
prefixFunctionCall | prefixFunctionCallOneArg // function call must always be at the end
// suffix/prefix composite
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index a5c42d9..6771536 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -46,12 +46,13 @@ class DataSetAggregate(
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
rowRelDataType: RelDataType,
inputType: RelDataType,
- grouping: Array[Int])
+ grouping: Array[Int],
+ inGroupingSet: Boolean)
extends SingleRel(cluster, traitSet, inputNode)
with FlinkAggregate
with DataSetRel {
- override def deriveRowType() = rowRelDataType
+ override def deriveRowType(): RelDataType = rowRelDataType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataSetAggregate(
@@ -61,7 +62,8 @@ class DataSetAggregate(
namedAggregates,
getRowType,
inputType,
- grouping)
+ grouping,
+ inGroupingSet)
}
override def toString: String = {
@@ -104,7 +106,8 @@ class DataSetAggregate(
namedAggregates,
inputType,
rowRelDataType,
- grouping)
+ grouping,
+ inGroupingSet)
val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
tableEnv,
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
index 9902486..6a3d4e3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
@@ -58,7 +58,7 @@ class DataStreamAggregate(
with FlinkAggregate
with DataStreamRel {
- override def deriveRowType() = rowRelDataType
+ override def deriveRowType(): RelDataType = rowRelDataType
override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
new DataStreamAggregate(
@@ -242,6 +242,7 @@ class DataStreamAggregate(
}
}
}
+
// if the expected type is not a Row, inject a mapper to convert to the expected type
expectedType match {
case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
index d634a6c..d1f932e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala
@@ -23,23 +23,21 @@ import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.convert.ConverterRule
import org.apache.calcite.rel.logical.LogicalAggregate
import org.apache.flink.table.api.TableException
-import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
-
+import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention, DataSetUnion}
import scala.collection.JavaConversions._
class DataSetAggregateRule
extends ConverterRule(
- classOf[LogicalAggregate],
- Convention.NONE,
- DataSetConvention.INSTANCE,
- "DataSetAggregateRule")
- {
+ classOf[LogicalAggregate],
+ Convention.NONE,
+ DataSetConvention.INSTANCE,
+ "DataSetAggregateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
- //for non grouped agg sets should attach null row to source data
- //need apply DataSetAggregateWithNullValuesRule
+ // for non-grouped agg sets we attach null row to source data
+ // we need to apply DataSetAggregateWithNullValuesRule
if (agg.getGroupSet.isEmpty) {
return false
}
@@ -50,13 +48,7 @@ class DataSetAggregateRule
throw TableException("DISTINCT aggregates are currently not supported.")
}
- // check if we have grouping sets
- val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet
- if (groupSets || agg.indicator) {
- throw TableException("GROUPING SETS are currently not supported.")
- }
-
- !distinctAggs && !groupSets && !agg.indicator
+ !distinctAggs
}
override def convert(rel: RelNode): RelNode = {
@@ -64,16 +56,43 @@ class DataSetAggregateRule
val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
val convInput: RelNode = RelOptRule.convert(agg.getInput, DataSetConvention.INSTANCE)
- new DataSetAggregate(
- rel.getCluster,
- traitSet,
- convInput,
- agg.getNamedAggCalls,
- rel.getRowType,
- agg.getInput.getRowType,
- agg.getGroupSet.toArray)
+ if (agg.indicator) {
+ agg.groupSets.map(set =>
+ new DataSetAggregate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ agg.getNamedAggCalls,
+ rel.getRowType,
+ agg.getInput.getRowType,
+ set.toArray,
+ inGroupingSet = true
+ ).asInstanceOf[RelNode]
+ ).reduce(
+ (rel1, rel2) => {
+ new DataSetUnion(
+ rel.getCluster,
+ traitSet,
+ rel1,
+ rel2,
+ rel.getRowType
+ )
+ }
+ )
+ } else {
+ new DataSetAggregate(
+ rel.getCluster,
+ traitSet,
+ convInput,
+ agg.getNamedAggCalls,
+ rel.getRowType,
+ agg.getInput.getRowType,
+ agg.getGroupSet.toArray,
+ inGroupingSet = false
+ )
}
}
+}
object DataSetAggregateRule {
val INSTANCE: RelOptRule = new DataSetAggregateRule
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
index b708af4..e8084fa 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala
@@ -29,22 +29,21 @@ import org.apache.flink.table.api.TableException
import org.apache.flink.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention}
/**
- * Rule for insert [[org.apache.flink.types.Row]] with null records into a [[DataSetAggregate]]
- * Rule apply for non grouped aggregate query
+ * Rule that inserts a [[org.apache.flink.types.Row]] with null records into a
+ * [[DataSetAggregate]]. The rule applies to non-grouped aggregate queries.
*/
class DataSetAggregateWithNullValuesRule
extends ConverterRule(
classOf[LogicalAggregate],
Convention.NONE,
DataSetConvention.INSTANCE,
- "DataSetAggregateWithNullValuesRule")
-{
+ "DataSetAggregateWithNullValuesRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate]
- //for grouped agg sets shouldn't attach of null row
- //need apply other rules. e.g. [[DataSetAggregateRule]]
+ // group sets shouldn't attach a null row
+ // we need to apply other rules. i.e. DataSetAggregateRule
if (!agg.getGroupSet.isEmpty) {
return false
}
@@ -55,12 +54,7 @@ class DataSetAggregateWithNullValuesRule
throw TableException("DISTINCT aggregates are currently not supported.")
}
- // check if we have grouping sets
- val groupSets = agg.getGroupSets.size() == 0 || agg.getGroupSets.get(0) != agg.getGroupSet
- if (groupSets || agg.indicator) {
- throw TableException("GROUPING SETS are currently not supported.")
- }
- !distinctAggs && !groupSets && !agg.indicator
+ !distinctAggs
}
override def convert(rel: RelNode): RelNode = {
@@ -87,7 +81,8 @@ class DataSetAggregateWithNullValuesRule
agg.getNamedAggCalls,
rel.getRowType,
agg.getInput.getRowType,
- agg.getGroupSet.toArray
+ agg.getGroupSet.toArray,
+ inGroupingSet = false
)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
index bf8a18e..09f05d7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala
@@ -33,8 +33,7 @@ class DataStreamAggregateRule
classOf[LogicalWindowAggregate],
Convention.NONE,
DataStreamConvention.INSTANCE,
- "DataStreamAggregateRule")
- {
+ "DataStreamAggregateRule") {
override def matches(call: RelOptRuleCall): Boolean = {
val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate]
@@ -75,4 +74,3 @@ class DataStreamAggregateRule
object DataStreamAggregateRule {
val INSTANCE: RelOptRule = new DataStreamAggregateRule
}
-
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
index 21a96e0..0033ff7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateMapFunction.scala
@@ -37,7 +37,7 @@ class AggregateMapFunction[IN, OUT](
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(aggFields)
- Preconditions.checkArgument(aggregates.size == aggFields.size)
+ Preconditions.checkArgument(aggregates.length == aggFields.length)
val partialRowLength = groupingKeys.length +
aggregates.map(_.intermediateDataType.length).sum
output = new Row(partialRowLength)
@@ -46,11 +46,11 @@ class AggregateMapFunction[IN, OUT](
override def map(value: IN): OUT = {
val input = value.asInstanceOf[Row]
- for (i <- 0 until aggregates.length) {
+ for (i <- aggregates.indices) {
val fieldValue = input.getField(aggFields(i))
aggregates(i).prepare(fieldValue, output)
}
- for (i <- 0 until groupingKeys.length) {
+ for (i <- groupingKeys.indices) {
output.setField(i, input.getField(groupingKeys(i)))
}
output.asInstanceOf[OUT]
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
index 31b85cd..5237ecf 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceCombineFunction.scala
@@ -25,28 +25,31 @@ import org.apache.flink.types.Row
import scala.collection.JavaConversions._
-
/**
* It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
* [[org.apache.flink.api.java.operators.GroupCombineOperator]]
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated value
+ * index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ * Row and output Row.
*/
class AggregateReduceCombineFunction(
private val aggregates: Array[Aggregate[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
+ private val groupingSetsMapping: Array[(Int, Int)],
private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends AggregateReduceGroupFunction(
aggregates,
groupKeysMapping,
aggregateMapping,
+ groupingSetsMapping,
intermediateRowArity,
finalRowArity)
with CombineFunction[Row, Row] {
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
index c1efebb..c147629 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateReduceGroupFunction.scala
@@ -20,38 +20,45 @@ package org.apache.flink.table.runtime.aggregate
import java.lang.Iterable
import org.apache.flink.api.common.functions.RichGroupReduceFunction
-import org.apache.flink.types.Row
import org.apache.flink.configuration.Configuration
+import org.apache.flink.types.Row
import org.apache.flink.util.{Collector, Preconditions}
import scala.collection.JavaConversions._
/**
- * It wraps the aggregate logic inside of
+ * It wraps the aggregate logic inside of
* [[org.apache.flink.api.java.operators.GroupReduceOperator]].
*
- * @param aggregates The aggregate functions.
- * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
- * and output Row.
- * @param aggregateMapping The index mapping between aggregate function list and aggregated value
- * index in output Row.
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ * and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated value
+ * index in output Row.
+ * @param groupingSetsMapping The index mapping of keys in grouping sets between intermediate
+ * Row and output Row.
*/
class AggregateReduceGroupFunction(
private val aggregates: Array[Aggregate[_ <: Any]],
private val groupKeysMapping: Array[(Int, Int)],
private val aggregateMapping: Array[(Int, Int)],
+ private val groupingSetsMapping: Array[(Int, Int)],
private val intermediateRowArity: Int,
private val finalRowArity: Int)
extends RichGroupReduceFunction[Row, Row] {
protected var aggregateBuffer: Row = _
private var output: Row = _
+ private var intermediateGroupKeys: Option[Array[Int]] = None
override def open(config: Configuration) {
Preconditions.checkNotNull(aggregates)
Preconditions.checkNotNull(groupKeysMapping)
aggregateBuffer = new Row(intermediateRowArity)
output = new Row(finalRowArity)
+ if (!groupingSetsMapping.isEmpty) {
+ intermediateGroupKeys = Some(groupKeysMapping.map(_._1))
+ }
}
/**
@@ -87,6 +94,14 @@ class AggregateReduceGroupFunction(
output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
}
+ // Evaluate additional values of grouping sets
+ if (intermediateGroupKeys.isDefined) {
+ groupingSetsMapping.foreach {
+ case (inputIndex, outputIndex) =>
+ output.setField(outputIndex, !intermediateGroupKeys.get.contains(inputIndex))
+ }
+ }
+
out.collect(output)
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 1e48288..6b7e03a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -241,10 +241,12 @@ object AggregateUtil {
*
*/
private[flink] def createAggregateGroupReduceFunction(
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- outputType: RelDataType,
- groupings: Array[Int]): RichGroupReduceFunction[Row, Row] = {
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ inputType: RelDataType,
+ outputType: RelDataType,
+ groupings: Array[Int],
+ inGroupingSet: Boolean)
+ : RichGroupReduceFunction[Row, Row] = {
val aggregates = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
@@ -258,6 +260,12 @@ object AggregateUtil {
outputType,
groupings)
+ val groupingSetsMapping: Array[(Int, Int)] = if (inGroupingSet) {
+ getGroupingSetsIndicatorMapping(inputType, outputType)
+ } else {
+ Array()
+ }
+
val allPartialAggregate: Boolean = aggregates.forall(_.supportPartial)
val intermediateRowArity = groupings.length +
@@ -269,6 +277,7 @@ object AggregateUtil {
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
+ groupingSetsMapping,
intermediateRowArity,
outputType.getFieldCount)
}
@@ -277,6 +286,7 @@ object AggregateUtil {
aggregates,
groupingOffsetMapping,
aggOffsetMapping,
+ groupingSetsMapping,
intermediateRowArity,
outputType.getFieldCount)
}
@@ -329,7 +339,8 @@ object AggregateUtil {
namedAggregates,
inputType,
outputType,
- groupings)
+ groupings,
+ inGroupingSet = false)
if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
@@ -358,7 +369,8 @@ object AggregateUtil {
namedAggregates,
inputType,
outputType,
- groupings)
+ groupings,
+ inGroupingSet = false)
if (isTimeWindow(window)) {
val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
@@ -371,7 +383,7 @@ object AggregateUtil {
/**
* Create an [[AllWindowFunction]] to finalize incrementally pre-computed non-partitioned
- * window aggreagtes.
+ * window aggregates.
*/
private[flink] def createAllWindowIncrementalAggregationFunction(
window: LogicalWindow,
@@ -495,6 +507,51 @@ object AggregateUtil {
(groupingOffsetMapping, aggOffsetMapping)
}
+ /**
+ * Determines the mapping of grouping keys to boolean indicators that describe the
+ * current grouping set.
+ *
+ * E.g.: Given we group on f1 and f2 of the input type, the output type contains two
+ * boolean indicator fields i$f1 and i$f2.
+ */
+ private def getGroupingSetsIndicatorMapping(
+ inputType: RelDataType,
+ outputType: RelDataType)
+ : Array[(Int, Int)] = {
+
+ val inputFields = inputType.getFieldList.map(_.getName)
+
+ // map from field -> i$field or field -> i$field_0
+ val groupingFields = inputFields.map(inputFieldName => {
+ val base = "i$" + inputFieldName
+ var name = base
+ var i = 0
+ while (inputFields.contains(name)) {
+ name = base + "_" + i // if i$XXX is already a field it will be suffixed by _NUMBER
+ i = i + 1
+ }
+ inputFieldName -> name
+ }).toMap
+
+ val outputFields = outputType.getFieldList
+
+ var mappingsBuffer = ArrayBuffer[(Int, Int)]()
+ for (i <- outputFields.indices) {
+ for (j <- outputFields.indices) {
+ val possibleKey = outputFields(i).getName
+ val possibleIndicator1 = outputFields(j).getName
+ // get indicator for output field
+ val possibleIndicator2 = groupingFields.getOrElse(possibleKey, null)
+
+ // check if indicator names match
+ if (possibleIndicator1 == possibleIndicator2) {
+ mappingsBuffer += ((i, j))
+ }
+ }
+ }
+ mappingsBuffer.toArray
+ }
+
private def isTimeWindow(window: LogicalWindow) = {
window match {
case ProcessingTimeTumblingGroupWindow(_, size) => isTimeInterval(size.resultType)
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index f92b3a1..c00f8bb 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -257,6 +257,10 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
SqlStdOperatorTable.NOT,
SqlStdOperatorTable.UNARY_MINUS,
SqlStdOperatorTable.UNARY_PLUS,
+ // GROUPING FUNCTIONS
+ SqlStdOperatorTable.GROUP_ID,
+ SqlStdOperatorTable.GROUPING,
+ SqlStdOperatorTable.GROUPING_ID,
// AGGREGATE OPERATORS
SqlStdOperatorTable.SUM,
SqlStdOperatorTable.COUNT,
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java
new file mode 100644
index 0000000..f7111f7
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/sql/GroupingSetsITCase.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.api.java.batch.sql;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.operators.MapOperator;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.api.java.BatchTableEnvironment;
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase;
+import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
+import org.apache.flink.test.util.TestBaseUtils;
+import org.apache.flink.types.Row;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Comparator;
+import java.util.List;
+
+@RunWith(Parameterized.class)
+public class GroupingSetsITCase extends TableProgramsTestBase {
+
+ private final static String TABLE_NAME = "MyTable";
+ private final static String TABLE_WITH_NULLS_NAME = "MyTableWithNulls";
+ private BatchTableEnvironment tableEnv;
+
+ /**
+ * Creates the test for one combination of the parameters supplied by the
+ * {@link Parameterized} runner; both values are passed unchanged to the
+ * {@link TableProgramsTestBase} constructor.
+ *
+ * @param mode execution mode forwarded to the base class.
+ * @param tableConfigMode table configuration mode forwarded to the base class.
+ */
+ public GroupingSetsITCase(TestExecutionMode mode, TableConfigMode tableConfigMode) {
+ super(mode, tableConfigMode);
+ }
+
+    /**
+     * Registers the two input tables used by every test: the plain 3-tuple data set under
+     * {@link #TABLE_NAME}, and a copy whose string field is nulled out whenever it contains
+     * "world" (case-insensitively) under {@link #TABLE_WITH_NULLS_NAME}.
+     */
+    @Before
+    public void setupTables() {
+        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+        tableEnv = TableEnvironment.getTableEnvironment(env, new TableConfig());
+
+        final DataSet<Tuple3<Integer, Long, String>> tuples = CollectionDataSets.get3TupleDataSet(env);
+        tableEnv.registerDataSet(TABLE_NAME, tuples);
+
+        // Replaces the third field by null for every record that mentions "world".
+        final MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> nullifyWorld =
+            new MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>>() {
+
+                @Override
+                public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> value) throws Exception {
+                    final boolean mentionsWorld = value.f2.toLowerCase().contains("world");
+                    if (mentionsWorld) {
+                        value.f2 = null;
+                    }
+                    return value;
+                }
+            };
+
+        final MapOperator<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> tuplesWithNulls =
+            tuples.map(nullifyWorld);
+        tableEnv.registerDataSet(TABLE_WITH_NULLS_NAME, tuplesWithNulls);
+    }
+
+ /**
+ * Runs a GROUPING SETS (f1, f2) aggregation over the table registered under
+ * {@link #TABLE_NAME} and checks the averages together with the GROUP_ID(),
+ * GROUPING() and GROUPING_ID() columns.
+ */
+ @Test
+ public void testGroupingSets() throws Exception {
+ String query =
+ "SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+ " GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+ " GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+ " GROUPING_ID(f1, f2) as gid " +
+ " FROM " + TABLE_NAME +
+ " GROUP BY GROUPING SETS (f1, f2)";
+
+ // Columns: f1, f2, a, g, gf1, gf2, gif1, gif2, gid. Rows produced by the f1 grouping set
+ // print null for f2 and vice versa. The f2 value "Hello world, how are you?" contains a
+ // comma itself, so its line has one more comma than the others.
+ String expected =
+ "1,null,1,1,0,1,0,1,1\n" +
+ "6,null,18,1,0,1,0,1,1\n" +
+ "2,null,2,1,0,1,0,1,1\n" +
+ "4,null,8,1,0,1,0,1,1\n" +
+ "5,null,13,1,0,1,0,1,1\n" +
+ "3,null,5,1,0,1,0,1,1\n" +
+ "null,Comment#11,17,2,1,0,1,0,2\n" +
+ "null,Comment#8,14,2,1,0,1,0,2\n" +
+ "null,Comment#2,8,2,1,0,1,0,2\n" +
+ "null,Comment#1,7,2,1,0,1,0,2\n" +
+ "null,Comment#14,20,2,1,0,1,0,2\n" +
+ "null,Comment#7,13,2,1,0,1,0,2\n" +
+ "null,Comment#6,12,2,1,0,1,0,2\n" +
+ "null,Comment#3,9,2,1,0,1,0,2\n" +
+ "null,Comment#12,18,2,1,0,1,0,2\n" +
+ "null,Comment#5,11,2,1,0,1,0,2\n" +
+ "null,Comment#15,21,2,1,0,1,0,2\n" +
+ "null,Comment#4,10,2,1,0,1,0,2\n" +
+ "null,Hi,1,2,1,0,1,0,2\n" +
+ "null,Comment#10,16,2,1,0,1,0,2\n" +
+ "null,Hello world,3,2,1,0,1,0,2\n" +
+ "null,I am fine.,5,2,1,0,1,0,2\n" +
+ "null,Hello world, how are you?,4,2,1,0,1,0,2\n" +
+ "null,Comment#9,15,2,1,0,1,0,2\n" +
+ "null,Comment#13,19,2,1,0,1,0,2\n" +
+ "null,Luke Skywalker,6,2,1,0,1,0,2\n" +
+ "null,Hello,2,2,1,0,1,0,2";
+
+ checkSql(query, expected);
+ }
+
+ /**
+ * Runs the GROUPING SETS (f1, f2) aggregation against the table registered under
+ * {@link #TABLE_WITH_NULLS_NAME} and checks the averages and GROUP_ID() values.
+ */
+ @Test
+ public void testGroupingSetsWithNulls() throws Exception {
+ String query =
+ "SELECT f1, f2, avg(f0) as a, GROUP_ID() as g FROM " + TABLE_WITH_NULLS_NAME +
+ " GROUP BY GROUPING SETS (f1, f2)";
+
+ // Columns: f1, f2, a, g. The "null,null,3,2" row belongs to the f2 grouping set (g = 2):
+ // its f2 grouping key is itself null, which is the case this test is meant to cover.
+ String expected =
+ "6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" +
+ "null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" +
+ "null,null,3,2\nnull,Hello,2,2\nnull,Comment#9,15,2\nnull,Comment#8,14,2\n" +
+ "null,Comment#7,13,2\nnull,Comment#6,12,2\nnull,Comment#5,11,2\n" +
+ "null,Comment#4,10,2\nnull,Comment#3,9,2\nnull,Comment#2,8,2\n" +
+ "null,Comment#15,21,2\nnull,Comment#14,20,2\nnull,Comment#13,19,2\n" +
+ "null,Comment#12,18,2\nnull,Comment#11,17,2\nnull,Comment#10,16,2\n" +
+ "null,Comment#1,7,2";
+
+ checkSql(query, expected);
+ }
+
+    /**
+     * Checks that GROUP BY CUBE (f1, f2) yields the same result as the explicit
+     * GROUPING SETS ((f1, f2), (f1), (f2), ()) formulation of the same query.
+     */
+    @Test
+    public void testCubeAsGroupingSets() throws Exception {
+        // Both queries share everything up to and including the FROM clause.
+        final String selectFromTable =
+            "SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+                " GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+                " GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+                " GROUPING_ID(f1, f2) as gid " +
+                " FROM " + TABLE_NAME;
+
+        final String viaCube = selectFromTable + " GROUP BY CUBE (f1, f2)";
+        final String viaGroupingSets =
+            selectFromTable + " GROUP BY GROUPING SETS ((f1, f2), (f1), (f2), ())";
+
+        compareSql(viaCube, viaGroupingSets);
+    }
+
+    /**
+     * Checks that GROUP BY ROLLUP (f1, f2) yields the same result as the explicit
+     * GROUPING SETS ((f1, f2), (f1), ()) formulation of the same query.
+     */
+    @Test
+    public void testRollupAsGroupingSets() throws Exception {
+        // Both queries share everything up to and including the FROM clause.
+        final String selectFromTable =
+            "SELECT f1, f2, avg(f0) as a, GROUP_ID() as g, " +
+                " GROUPING(f1) as gf1, GROUPING(f2) as gf2, " +
+                " GROUPING_ID(f1) as gif1, GROUPING_ID(f2) as gif2, " +
+                " GROUPING_ID(f1, f2) as gid " +
+                " FROM " + TABLE_NAME;
+
+        final String viaRollup = selectFromTable + " GROUP BY ROLLUP (f1, f2)";
+        final String viaGroupingSets =
+            selectFromTable + " GROUP BY GROUPING SETS ((f1, f2), (f1), ())";
+
+        compareSql(viaRollup, viaGroupingSets);
+    }
+
+    /**
+     * Runs the given SQL query, collects its rows and compares their text form with the
+     * expected result.
+     *
+     * @param query SQL query to execute.
+     * @param expected Expected rows, one per line.
+     */
+    private void checkSql(String query, String expected) throws Exception {
+        final List<Row> actualRows = tableEnv.toDataSet(tableEnv.sql(query), Row.class).collect();
+        TestBaseUtils.compareResultAsText(actualRows, expected);
+    }
+
+    /**
+     * Executes two SQL queries and checks that they produce the same result.
+     *
+     * <p>Every row of both results is rendered as a string (a null row becomes "null") and the
+     * two lists are handed to {@link TestBaseUtils#compareResultCollections} together with a
+     * null-safe comparator.
+     *
+     * @param query1 First SQL query.
+     * @param query2 Second SQL query, expected to be equivalent to the first one.
+     */
+    private void compareSql(String query1, String query2) throws Exception {
+
+        // Function to map row to string
+        MapFunction<Row, String> mapFunction = new MapFunction<Row, String>() {
+
+            @Override
+            public String map(Row value) throws Exception {
+                return value == null ? "null" : value.toString();
+            }
+        };
+
+        // Execute first query and store results
+        Table resultTable1 = tableEnv.sql(query1);
+        DataSet<Row> resultDataSet1 = tableEnv.toDataSet(resultTable1, Row.class);
+        List<String> results1 = resultDataSet1.map(mapFunction).collect();
+
+        // Execute second query and store results
+        Table resultTable2 = tableEnv.sql(query2);
+        DataSet<Row> resultDataSet2 = tableEnv.toDataSet(resultTable2, Row.class);
+        List<String> results2 = resultDataSet2.map(mapFunction).collect();
+
+        // Compare results
+        TestBaseUtils.compareResultCollections(results1, results2, new Comparator<String>() {
+
+            @Override
+            public int compare(String o1, String o2) {
+                // Nulls sort before every non-null value. Both operands are checked so that
+                // compare(null, "x") returns -1 instead of dereferencing a null o1.
+                if (o1 == null) {
+                    return o2 == null ? 0 : -1;
+                }
+                if (o2 == null) {
+                    return 1;
+                }
+                return o1.compareTo(o2);
+            }
+        });
+    }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
index 4f55bee..d6f2b7b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/AggregationsITCase.scala
@@ -245,19 +245,31 @@ class AggregationsITCase(
tEnv.sql(sqlQuery).toDataSet[Row]
}
- @Test(expected = classOf[TableException])
+ @Test
 def testGroupingSetAggregate(): Unit = {
 val env = ExecutionEnvironment.getExecutionEnvironment
 val tEnv = TableEnvironment.getTableEnvironment(env, config)
- val sqlQuery = "SELECT _2, _3, avg(_1) as a FROM MyTable GROUP BY GROUPING SETS (_2, _3)"
+ val sqlQuery =
+ "SELECT _2, _3, avg(_1) as a, GROUP_ID() as g FROM MyTable GROUP BY GROUPING SETS (_2, _3)"
 val ds = CollectionDataSets.get3TupleDataSet(env)
 tEnv.registerDataSet("MyTable", ds)
- // must fail. grouping sets are not supported
- tEnv.sql(sqlQuery).toDataSet[Row]
+ val result = tEnv.sql(sqlQuery).toDataSet[Row].collect()
+
+ // Columns: _2, _3, a, g. Rows produced by the _2 grouping set print null for _3 and vice
+ // versa; the _3 value "Hello world, how are you?" contains a comma of its own.
+ val expected =
+ "6,null,18,1\n5,null,13,1\n4,null,8,1\n3,null,5,1\n2,null,2,1\n1,null,1,1\n" +
+ "null,Luke Skywalker,6,2\nnull,I am fine.,5,2\nnull,Hi,1,2\n" +
+ "null,Hello world, how are you?,4,2\nnull,Hello world,3,2\nnull,Hello,2,2\n" +
+ "null,Comment#9,15,2\nnull,Comment#8,14,2\nnull,Comment#7,13,2\n" +
+ "null,Comment#6,12,2\nnull,Comment#5,11,2\nnull,Comment#4,10,2\n" +
+ "null,Comment#3,9,2\nnull,Comment#2,8,2\nnull,Comment#15,21,2\n" +
+ "null,Comment#14,20,2\nnull,Comment#13,19,2\nnull,Comment#12,18,2\n" +
+ "null,Comment#11,17,2\nnull,Comment#10,16,2\nnull,Comment#1,7,2"
+
+ TestBaseUtils.compareResultAsText(result.asJava, expected)
 }
@Test
http://git-wip-us.apache.org/repos/asf/flink/blob/ef8cdfe5/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala
new file mode 100644
index 0000000..c12e5a6
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/GroupingSetsTest.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.api.scala.batch.sql
+
+import org.apache.flink.api.scala._
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.utils.TableTestBase
+import org.apache.flink.table.utils.TableTestUtil._
+import org.junit.Test
+
+class GroupingSetsTest extends TableTestBase {
+
+  /** Verifies the batch plan produced for `GROUP BY GROUPING SETS (b, c)`. */
+  @Test
+  def testGroupingSets(): Unit = {
+    val testUtil = batchTestUtil()
+    testUtil.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery =
+      "SELECT b, c, avg(a) as a, GROUP_ID() as g " +
+        "FROM MyTable " +
+        "GROUP BY GROUPING SETS (b, c)"
+
+    // One aggregate per grouping set.
+    val groupedByB = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b"),
+      term("select", "b", "AVG(a) AS c")
+    )
+    val groupedByC = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "c"),
+      term("select", "c AS b", "AVG(a) AS c")
+    )
+
+    // The grouping sets are combined by a union ...
+    val unionOfGroups = binaryNode(
+      "DataSetUnion",
+      groupedByB,
+      groupedByC,
+      term("union", "b", "c", "i$b", "i$c", "a")
+    )
+
+    // ... and a final calc derives the output columns from the indicator fields.
+    val expectedPlan = unaryNode(
+      "DataSetCalc",
+      unionOfGroups,
+      term("select",
+        "CASE(i$b, null, b) AS b",
+        "CASE(i$c, null, c) AS c",
+        "a",
+        "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g") // GROUP_ID()
+    )
+
+    testUtil.verifySql(sqlQuery, expectedPlan)
+  }
+
+  /** Verifies the batch plan produced for `GROUP BY CUBE (b, c)`. */
+  @Test
+  def testCube(): Unit = {
+    val testUtil = batchTestUtil()
+    testUtil.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery =
+      "SELECT b, c, avg(a) as a, GROUP_ID() as g, " +
+        "GROUPING(b) as gb, GROUPING(c) as gc, " +
+        "GROUPING_ID(b) as gib, GROUPING_ID(c) as gic, " +
+        "GROUPING_ID(b, c) as gid " +
+        "FROM MyTable " +
+        "GROUP BY CUBE (b, c)"
+
+    // Every union node in the expected plan lists the same fields.
+    def unionFields = term("union", "b", "c", "i$b", "i$c", "a")
+
+    // One aggregate per grouping set of the cube: (b, c), (b), (c) and ().
+    val groupedByBAndC = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b, c"),
+      term("select", "b", "c", "AVG(a) AS i$b")
+    )
+    val groupedByB = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b"),
+      term("select", "b", "AVG(a) AS c")
+    )
+    val groupedByC = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "c"),
+      term("select", "c AS b", "AVG(a) AS c")
+    )
+    val ungrouped = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("select", "AVG(a) AS b")
+    )
+
+    // Left-deep chain of unions: (((b, c) U b) U c) U ().
+    val unionOfGroups = Seq(groupedByB, groupedByC, ungrouped).foldLeft(groupedByBAndC) {
+      (left, right) => binaryNode("DataSetUnion", left, right, unionFields)
+    }
+
+    val expectedPlan = unaryNode(
+      "DataSetCalc",
+      unionOfGroups,
+      term("select",
+        "CASE(i$b, null, b) AS b",
+        "CASE(i$c, null, c) AS c",
+        "a",
+        "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g", // GROUP_ID()
+        "CASE(i$b, 1, 0) AS gb", // GROUPING(b)
+        "CASE(i$c, 1, 0) AS gc", // GROUPING(c)
+        "CASE(i$b, 1, 0) AS gib", // GROUPING_ID(b)
+        "CASE(i$c, 1, 0) AS gic", // GROUPING_ID(c)
+        "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS gid") // GROUPING_ID(b, c)
+    )
+
+    testUtil.verifySql(sqlQuery, expectedPlan)
+  }
+
+  /** Verifies the batch plan produced for `GROUP BY ROLLUP (b, c)`. */
+  @Test
+  def testRollup(): Unit = {
+    val testUtil = batchTestUtil()
+    testUtil.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c)
+
+    val sqlQuery =
+      "SELECT b, c, avg(a) as a, GROUP_ID() as g, " +
+        "GROUPING(b) as gb, GROUPING(c) as gc, " +
+        "GROUPING_ID(b) as gib, GROUPING_ID(c) as gic, " +
+        "GROUPING_ID(b, c) as gid " +
+        " FROM MyTable " +
+        "GROUP BY ROLLUP (b, c)"
+
+    // Every union node in the expected plan lists the same fields.
+    def unionFields = term("union", "b", "c", "i$b", "i$c", "a")
+
+    // One aggregate per grouping set of the rollup: (b, c), (b) and ().
+    val groupedByBAndC = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b, c"),
+      term("select", "b", "c", "AVG(a) AS i$b")
+    )
+    val groupedByB = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("groupBy", "b"),
+      term("select", "b", "AVG(a) AS c")
+    )
+    val ungrouped = unaryNode(
+      "DataSetAggregate",
+      batchTableNode(0),
+      term("select", "AVG(a) AS b")
+    )
+
+    // Left-deep chain of unions: ((b, c) U b) U ().
+    val unionOfGroups = Seq(groupedByB, ungrouped).foldLeft(groupedByBAndC) {
+      (left, right) => binaryNode("DataSetUnion", left, right, unionFields)
+    }
+
+    val expectedPlan = unaryNode(
+      "DataSetCalc",
+      unionOfGroups,
+      term("select",
+        "CASE(i$b, null, b) AS b",
+        "CASE(i$c, null, c) AS c",
+        "a",
+        "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS g", // GROUP_ID()
+        "CASE(i$b, 1, 0) AS gb", // GROUPING(b)
+        "CASE(i$c, 1, 0) AS gc", // GROUPING(c)
+        "CASE(i$b, 1, 0) AS gib", // GROUPING_ID(b)
+        "CASE(i$c, 1, 0) AS gic", // GROUPING_ID(c)
+        "+(*(CASE(i$b, 1, 0), 2), CASE(i$c, 1, 0)) AS gid") // GROUPING_ID(b, c)
+    )
+
+    testUtil.verifySql(sqlQuery, expectedPlan)
+  }
+}