Posted to commits@spark.apache.org by da...@apache.org on 2016/02/11 05:13:50 UTC

spark git commit: [SPARK-12706] [SQL] grouping() and grouping_id()

Repository: spark
Updated Branches:
  refs/heads/master 0f09f0226 -> b5761d150


[SPARK-12706] [SQL] grouping() and grouping_id()

grouping() indicates whether a column in the GROUP BY list is aggregated or not; grouping_id() returns the level of aggregation.

grouping()/grouping_id() can be used with window functions, but do not yet work in HAVING/ORDER BY clauses; that will be fixed in another PR.

The GROUPING__ID/grouping_id() implementation in Hive is wrong (it does not match Hive's own docs), and we had copied that wrong behavior. This PR changes it to match the behavior of most databases (and of the Hive docs).
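
For example, with GROUP BY CUBE(course, year) the bits are now assigned left to right: grouping_id() = (grouping(course) << 1) | grouping(year), so fully grouped rows get 0, rows that aggregate away only year get 1, and the grand-total row gets 3. The old behavior produced the inverted values.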

Author: Davies Liu <da...@databricks.com>

Closes #10677 from davies/grouping.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5761d15
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5761d15
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5761d15

Branch: refs/heads/master
Commit: b5761d150b66ee0ae5f1be897d9d7a1abb039884
Parents: 0f09f02
Author: Davies Liu <da...@databricks.com>
Authored: Wed Feb 10 20:13:38 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Wed Feb 10 20:13:38 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 22 ++++-----
 python/pyspark/sql/functions.py                 | 44 +++++++++++++++++
 .../apache/spark/sql/catalyst/CatalystQl.scala  | 13 +++--
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 48 +++++++++++++++----
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  5 ++
 .../catalyst/analysis/FunctionRegistry.scala    |  2 +
 .../expressions/aggregate/interfaces.scala      |  1 -
 .../sql/catalyst/expressions/grouping.scala     | 23 +++++++++
 .../catalyst/plans/logical/basicOperators.scala |  2 +-
 .../scala/org/apache/spark/sql/functions.scala  | 46 ++++++++++++++++++
 .../spark/sql/DataFrameAggregateSuite.scala     | 44 +++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 50 ++++++++++++++++++++
 .../hive/execution/HiveCompatibilitySuite.scala |  8 ++--
 ...r CUBE #1-0-63b61fb3f0e74226001ad279be440864 | 12 ++---
 ...Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a | 12 ++---
 ...Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 | 20 ++++----
 ...Rollup #3-0-9257085d123728730be96b6d9fbb84ce | 20 ++++----
 17 files changed, 309 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3a8c830..3104e41 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -887,8 +887,8 @@ class DataFrame(object):
         [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
         >>> sorted(df.groupBy(df.name).avg().collect())
         [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
-        >>> df.groupBy(['name', df.age]).count().collect()
-        [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
+        >>> sorted(df.groupBy(['name', df.age]).count().collect())
+        [Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)]
         """
         jgd = self._jdf.groupBy(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
@@ -900,15 +900,15 @@ class DataFrame(object):
         Create a multi-dimensional rollup for the current :class:`DataFrame` using
         the specified columns, so we can run aggregation on them.
 
-        >>> df.rollup('name', df.age).count().show()
+        >>> df.rollup("name", df.age).count().orderBy("name", "age").show()
         +-----+----+-----+
         | name| age|count|
         +-----+----+-----+
-        |Alice|   2|    1|
-        |  Bob|   5|    1|
-        |  Bob|null|    1|
         | null|null|    2|
         |Alice|null|    1|
+        |Alice|   2|    1|
+        |  Bob|null|    1|
+        |  Bob|   5|    1|
         +-----+----+-----+
         """
         jgd = self._jdf.rollup(self._jcols(*cols))
@@ -921,17 +921,17 @@ class DataFrame(object):
         Create a multi-dimensional cube for the current :class:`DataFrame` using
         the specified columns, so we can run aggregation on them.
 
-        >>> df.cube('name', df.age).count().show()
+        >>> df.cube("name", df.age).count().orderBy("name", "age").show()
         +-----+----+-----+
         | name| age|count|
         +-----+----+-----+
+        | null|null|    2|
         | null|   2|    1|
-        |Alice|   2|    1|
-        |  Bob|   5|    1|
         | null|   5|    1|
-        |  Bob|null|    1|
-        | null|null|    2|
         |Alice|null|    1|
+        |Alice|   2|    1|
+        |  Bob|null|    1|
+        |  Bob|   5|    1|
         +-----+----+-----+
         """
         jgd = self._jdf.cube(self._jcols(*cols))

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 0d57085..680493e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -288,6 +288,50 @@ def first(col, ignorenulls=False):
     return Column(jc)
 
 
+@since(2.0)
+def grouping(col):
+    """
+    Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+    or not; returns 1 if the column is aggregated and 0 otherwise.
+
+    >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
+    +-----+--------------+--------+
+    | name|grouping(name)|sum(age)|
+    +-----+--------------+--------+
+    | null|             1|       7|
+    |Alice|             0|       2|
+    |  Bob|             0|       5|
+    +-----+--------------+--------+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.grouping(_to_java_column(col))
+    return Column(jc)
+
+
+@since(2.0)
+def grouping_id(*cols):
+    """
+    Aggregate function: returns the level of grouping, equal to
+
+       (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+
+    Note: the list of columns should exactly match the grouping columns, or be empty (meaning
+    all the grouping columns).
+
+    >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+    +-----+------------+--------+
+    | name|groupingid()|sum(age)|
+    +-----+------------+--------+
+    | null|           1|       7|
+    |Alice|           0|       2|
+    |  Bob|           0|       5|
+    +-----+------------+--------+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 @since(1.6)
 def input_file_name():
     """Creates a string column for the file name of the current Spark task.

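The doctest above uses a single grouping column, so grouping_id() reduces to grouping(name): 1 on the all-aggregated row and 0 elsewhere. With two columns, say cube("c1", "c2"), the formula gives grouping(c1) * 2 + grouping(c2), i.e. values 0 through 3.
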
http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index a42360d..8099751 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -186,8 +186,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
    *
    * The bitmask denotes the grouping expressions validity for a grouping set,
    * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive)
-   * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of
-   * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively.
+   * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of
+   * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively.
    */
   protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
     val (keyASTs, setASTs) = children.partition {
@@ -198,12 +198,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     val keys = keyASTs.map(nodeToExpr)
     val keyMap = keyASTs.zipWithIndex.toMap
 
+    val mask = (1 << keys.length) - 1
     val bitmasks: Seq[Int] = setASTs.map {
       case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
-        columns.foldLeft(0)((bitmap, col) => {
-          val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
-          bitmap | 1 << keyIndex.getOrElse(
+        columns.foldLeft(mask)((bitmap, col) => {
+          val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse(
             throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
+          // 0 means that the column at the given index is a grouping column, 1 means it is not,
+          // so we unset the bit in bitmap.
+          bitmap & ~(1 << (keys.length - 1 - keyIndex))
         })
       case _ => sys.error("Expect GROUPING SETS clause")
     }

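As a sanity check of the bitmask arithmetic above, here is a minimal standalone Scala sketch (key names are hypothetical) reproducing the example from the doc comment, where GROUPING SETS (k2) in the superset (k1, k2, k3) yields 5:

    val keys = Seq("k1", "k2", "k3")
    val set = Seq("k2")
    val mask = (1 << keys.length) - 1               // 0b111: start with all bits set
    val bitmask = set.foldLeft(mask) { (bm, col) =>
      val keyIndex = keys.indexOf(col)
      bm & ~(1 << (keys.length - 1 - keyIndex))     // clear the bit of each grouped column
    }
    assert(bitmask == 5)                            // 0b101: only k2 (bit 1) is grouped
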
http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62b241f..c0fa796 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -238,14 +238,39 @@ class Analyzer(
           }
         }.toMap
 
-        val aggregations: Seq[NamedExpression] = x.aggregations.map {
-          // If an expression is an aggregate (contains a AggregateExpression) then we dont change
-          // it so that the aggregation is computed on the unmodified value of its argument
-          // expressions.
-          case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
-          // If not then its a grouping expression and we need to use the modified (with nulls from
-          // Expand) value of the expression.
-          case expr => expr.transformDown {
+        val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
+          // Collect all the AggregateExpressions found, so we can check whether an expression
+          // is part of any AggregateExpression or not.
+          val aggsBuffer = ArrayBuffer[Expression]()
+          // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
+          def isPartOfAggregation(e: Expression): Boolean = {
+            aggsBuffer.exists(a => a.find(_ eq e).isDefined)
+          }
+          expr.transformDown {
+            // AggregateExpression should be computed on the unmodified value of its argument
+            // expressions, so we should not replace any references to grouping expression
+            // inside it.
+            case e: AggregateExpression =>
+              aggsBuffer += e
+              e
+            case e if isPartOfAggregation(e) => e
+            case e: GroupingID =>
+              if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
+                gid
+              } else {
+                throw new AnalysisException(
+                  s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
+                    s"grouping columns (${x.groupByExprs.mkString(",")})")
+              }
+            case Grouping(col: Expression) =>
+              val idx = x.groupByExprs.indexOf(col)
+              if (idx >= 0) {
+                Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
+                  Literal(1)), ByteType)
+              } else {
+                throw new AnalysisException(s"Column of grouping ($col) can't be found " +
+                  s"in grouping columns ${x.groupByExprs.mkString(",")}")
+              }
             case e =>
               groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e)
           }.asInstanceOf[NamedExpression]
@@ -819,8 +844,11 @@ class Analyzer(
         }
     }
 
+    private def isAggregateExpression(e: Expression): Boolean = {
+      e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
+    }
     def containsAggregate(condition: Expression): Boolean = {
-      condition.find(_.isInstanceOf[AggregateExpression]).isDefined
+      condition.find(isAggregateExpression).isDefined
     }
   }
 
@@ -1002,7 +1030,7 @@ class Analyzer(
         _.transform {
           // Extracts children expressions of a WindowFunction (input parameters of
           // a WindowFunction).
-          case wf : WindowFunction =>
+          case wf: WindowFunction =>
             val newChildren = wf.children.map(extractExpr)
             wf.withNewChildren(newChildren)
 

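The rewrite of Grouping(col) above is just a bit test against the virtual grouping id. A minimal sketch of the equivalent integer arithmetic (helper name is hypothetical), assuming n grouping columns with col at index idx:

    def groupingBit(gid: Int, idx: Int, n: Int): Byte =
      ((gid >> (n - 1 - idx)) & 1).toByte

    // cube(course, year): n = 2. The grand-total row has gid = 3, so both
    // grouping(course) and grouping(year) are 1; gid = 2 aggregates only course.
    assert(groupingBit(3, 0, 2) == 1 && groupingBit(3, 1, 2) == 1)
    assert(groupingBit(2, 0, 2) == 1 && groupingBit(2, 1, 2) == 0)
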
http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 4a2f2b8..fe053b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -70,6 +70,11 @@ trait CheckAnalysis {
             failAnalysis(
               s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
 
+          case g: Grouping =>
+            failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup")
+          case g: GroupingID =>
+            failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
+
           case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
             failAnalysis(s"Distinct window functions are not supported: $w")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d9009e3..1be97c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -291,6 +291,8 @@ object FunctionRegistry {
     // grouping sets
     expression[Cube]("cube"),
     expression[Rollup]("rollup"),
+    expression[Grouping]("grouping"),
+    expression[GroupingID]("grouping_id"),
 
     // window functions
     expression[Lead]("lead"),

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 561fa33..f88a57a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -344,4 +344,3 @@ abstract class DeclarativeAggregate
     def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
   }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 2997ee8..a204060 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -41,3 +41,26 @@ trait GroupingSet extends Expression with CodegenFallback {
 case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}
 
 case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}
+
+/**
+  * Indicates whether a specified column expression in a GROUP BY list is aggregated or not.
+  * GROUPING returns 1 if the column is aggregated and 0 otherwise.
+  */
+case class Grouping(child: Expression) extends Expression with Unevaluable {
+  override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
+  override def children: Seq[Expression] = child :: Nil
+  override def dataType: DataType = ByteType
+  override def nullable: Boolean = false
+}
+
+/**
+  * GroupingID is a function that computes the level of grouping.
+  *
+  * If groupByExprs is empty, it means all grouping expressions in GroupingSets.
+  */
+case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable {
+  override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
+  override def children: Seq[Expression] = groupByExprs
+  override def dataType: DataType = IntegerType
+  override def nullable: Boolean = false
+}

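Note that both Grouping and GroupingID are Unevaluable placeholders: they only reference the virtual GROUPING__ID column, and the Analyzer change above replaces them with arithmetic on the real gid attribute; any occurrence outside GroupingSets/Cube/Rollup is rejected by the new CheckAnalysis rules.
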
http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 57575f9..e8e0a78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -412,7 +412,7 @@ private[sql] object Expand {
 
     var bit = exprs.length - 1
     while (bit >= 0) {
-      if (((bitmask >> bit) & 1) == 0) set += exprs(bit)
+      if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
       bit -= 1
     }
 
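Under the new convention a set bit means the corresponding expression is aggregated (nulled out) in that grouping set, and bits are indexed from the left, matching grouping_id(). A small sketch of reading a bitmask back (function name is hypothetical):

    // Indices of grouping expressions that are nulled out for this grouping set.
    def nulledIndices(bitmask: Int, n: Int): Seq[Int] =
      (0 until n).filter(i => ((bitmask >> (n - 1 - i)) & 1) == 1)

    assert(nulledIndices(2, 2) == Seq(0))     // 0b10: first expr aggregated
    assert(nulledIndices(3, 2) == Seq(0, 1))  // 0b11: grand total, all aggregated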

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b970eee..d34d377 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -396,6 +396,52 @@ object functions extends LegacyFunctions {
     */
   def first(columnName: String): Column = first(Column(columnName))
 
+
+  /**
+    * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+    * or not; returns 1 if the column is aggregated and 0 otherwise.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def grouping(e: Column): Column = Column(Grouping(e.expr))
+
+  /**
+    * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+    * or not; returns 1 if the column is aggregated and 0 otherwise.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def grouping(columnName: String): Column = grouping(Column(columnName))
+
+  /**
+    * Aggregate function: returns the level of grouping, equal to
+    *
+    *   (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+    *
+    * Note: the list of columns should exactly match the grouping columns, or be empty (meaning
+    * all the grouping columns).
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr)))
+
+  /**
+    * Aggregate function: returns the level of grouping, equal to
+    *
+    *   (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+    *
+    * Note: the list of columns should exactly match the grouping columns.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def grouping_id(colName: String, colNames: String*): Column = {
+    grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*)
+  }
+
   /**
    * Aggregate function: returns the kurtosis of the values in a group.
    *

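A short usage sketch of the new Scala API (the courseSales DataFrame is the one defined in the test suites below):

    import org.apache.spark.sql.functions._

    courseSales.cube("course", "year")
      .agg(grouping("course"), grouping("year"),
        grouping_id("course", "year"), sum("earnings"))
      .show()
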
http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 08fb7c9..78bf6c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.DecimalType
@@ -98,6 +99,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(cube0.where("date IS NULL").count > 0)
   }
 
+  test("grouping and grouping_id") {
+    checkAnswer(
+      courseSales.cube("course", "year")
+        .agg(grouping("course"), grouping("year"), grouping_id("course", "year")),
+      Row("Java", 2012, 0, 0, 0) ::
+        Row("Java", 2013, 0, 0, 0) ::
+        Row("Java", null, 0, 1, 1) ::
+        Row("dotNET", 2012, 0, 0, 0) ::
+        Row("dotNET", 2013, 0, 0, 0) ::
+        Row("dotNET", null, 0, 1, 1) ::
+        Row(null, 2012, 1, 0, 2) ::
+        Row(null, 2013, 1, 0, 2) ::
+        Row(null, null, 1, 1, 3) :: Nil
+    )
+
+    intercept[AnalysisException] {
+      courseSales.groupBy().agg(grouping("course")).explain()
+    }
+    intercept[AnalysisException] {
+      courseSales.groupBy().agg(grouping_id("course")).explain()
+    }
+  }
+
+  test("grouping/grouping_id inside window function") {
+
+    val w = Window.orderBy(sum("earnings"))
+    checkAnswer(
+      courseSales.cube("course", "year")
+        .agg(sum("earnings"),
+          grouping_id("course", "year"),
+          rank().over(Window.partitionBy(grouping_id("course", "year")).orderBy(sum("earnings")))),
+      Row("Java", 2012, 20000.0, 0, 2) ::
+        Row("Java", 2013, 30000.0, 0, 3) ::
+        Row("Java", null, 50000.0, 1, 1) ::
+        Row("dotNET", 2012, 15000.0, 0, 1) ::
+        Row("dotNET", 2013, 48000.0, 0, 4) ::
+        Row("dotNET", null, 63000.0, 1, 2) ::
+        Row(null, 2012, 35000.0, 2, 1) ::
+        Row(null, 2013, 78000.0, 2, 2) ::
+        Row(null, null, 113000.0, 3, 1) :: Nil
+    )
+  }
+
   test("rollup overlapping columns") {
     checkAnswer(
       testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8ef7b61..f665a1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2055,6 +2055,56 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("grouping sets") {
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by course, year " +
+        "grouping sets(course, year)"),
+      Row("Java", null, 50000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, 2012, 35000.0) ::
+        Row(null, 2013, 78000.0) :: Nil
+    )
+
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by course, year " +
+        "grouping sets(course)"),
+      Row("Java", null, 50000.0) ::
+        Row("dotNET", null, 63000.0) :: Nil
+    )
+
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by course, year " +
+        "grouping sets(year)"),
+      Row(null, 2012, 35000.0) ::
+        Row(null, 2013, 78000.0) :: Nil
+    )
+  }
+
+  test("grouping and grouping_id") {
+    checkAnswer(
+      sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" +
+        " from courseSales group by cube(course, year)"),
+      Row("Java", 2012, 0, 0, 0) ::
+        Row("Java", 2013, 0, 0, 0) ::
+        Row("Java", null, 0, 1, 1) ::
+        Row("dotNET", 2012, 0, 0, 0) ::
+        Row("dotNET", 2013, 0, 0, 0) ::
+        Row("dotNET", null, 0, 1, 1) ::
+        Row(null, 2012, 1, 0, 2) ::
+        Row(null, 2013, 1, 0, 2) ::
+        Row(null, null, 1, 1, 3) :: Nil
+    )
+
+    var error = intercept[AnalysisException] {
+      sql("select course, year, grouping(course) from courseSales group by course, year")
+    }
+    assert(error.getMessage contains "grouping() can only be used with GroupingSets/Cube/Rollup")
+    error = intercept[AnalysisException] {
+      sql("select course, year, grouping_id(course, year) from courseSales group by course, year")
+    }
+    assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup")
+  }
+
   test("SPARK-13056: Null in map value causes NPE") {
     val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
     withTempTable("maptest") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 61b73fa..9097c1a 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -328,6 +328,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     // Hive returns null rather than NaN when n = 1
     "udaf_covar_samp",
 
+    // The implementation of GROUPING__ID in Hive is wrong (it does not match the docs).
+    "groupby_grouping_id1",
+    "groupby_grouping_id2",
+    "groupby_grouping_sets1",
+
     // Spark parser treats numerical literals differently: it creates decimals instead of doubles.
     "udf_abs",
     "udf_format_number",
@@ -503,9 +508,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "groupby11",
     "groupby12",
     "groupby1_limit",
-    "groupby_grouping_id1",
-    "groupby_grouping_id2",
-    "groupby_grouping_sets1",
     "groupby_grouping_sets2",
     "groupby_grouping_sets3",
     "groupby_grouping_sets4",

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864
index dac1b84..c066aee 100644
--- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864	
+++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864	
@@ -1,6 +1,6 @@
-500	NULL	0
-91	0	1
-84	1	1
-105	2	1
-113	3	1
-107	4	1
+500	NULL	1
+91	0	0
+84	1	0
+105	2	0
+113	3	0
+107	4	0

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a
index dac1b84..c066aee 100644
--- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a	
+++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a	
@@ -1,6 +1,6 @@
-500	NULL	0
-91	0	1
-84	1	1
-105	2	1
-113	3	1
-107	4	1
+500	NULL	1
+91	0	0
+84	1	0
+105	2	0
+113	3	0
+107	4	0

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89
index 1eea4a9..fcacbe3 100644
--- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89	
+++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89	
@@ -1,10 +1,10 @@
-1	0	5	3
-1	0	15	3
-1	0	25	3
-1	0	60	3
-1	0	75	3
-1	0	80	3
-1	0	100	3
-1	0	140	3
-1	0	145	3
-1	0	150	3
+1	0	5	0
+1	0	15	0
+1	0	25	0
+1	0	60	0
+1	0	75	0
+1	0	80	0
+1	0	100	0
+1	0	140	0
+1	0	145	0
+1	0	150	0

http://git-wip-us.apache.org/repos/asf/spark/blob/b5761d15/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce
index 1eea4a9..fcacbe3 100644
--- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce	
+++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce	
@@ -1,10 +1,10 @@
-1	0	5	3
-1	0	15	3
-1	0	25	3
-1	0	60	3
-1	0	75	3
-1	0	80	3
-1	0	100	3
-1	0	140	3
-1	0	145	3
-1	0	150	3
+1	0	5	0
+1	0	15	0
+1	0	25	0
+1	0	60	0
+1	0	75	0
+1	0	80	0
+1	0	100	0
+1	0	140	0
+1	0	145	0
+1	0	150	0

