You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2016/03/05 12:25:12 UTC

spark git commit: [SPARK-12720][SQL] SQL Generation Support for Cube, Rollup, and Grouping Sets

Repository: spark
Updated Branches:
  refs/heads/master f19228eed -> adce5ee72


[SPARK-12720][SQL] SQL Generation Support for Cube, Rollup, and Grouping Sets

#### What changes were proposed in this pull request?

This PR is for supporting SQL generation for cube, rollup and grouping sets.

For example, a query using rollup:
```SQL
SELECT count(*) as cnt, key % 5, grouping_id() FROM t1 GROUP BY key % 5 WITH ROLLUP
```
Original logical plan:
```
  Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
            [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
             (key#17L % cast(5 as bigint))#47L AS _c1#45L,
             grouping__id#46 AS _c2#44]
  +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
             List(key#17L, value#18, null, 1)],
            [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
     +- Project [key#17L,
                 value#18,
                 (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
        +- Subquery t1
           +- Relation[key#17L,value#18] ParquetRelation
```
Converted SQL:
```SQL
  SELECT count( 1) AS `cnt`,
         (`t1`.`key` % CAST(5 AS BIGINT)),
         grouping_id() AS `_c2`
  FROM `default`.`t1`
  GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
  GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
```

#### How was this patch tested?

Added eleven test cases in `LogicalPlanToSQLSuite` (rollup/cube #1–#9 and grouping sets #1–#2).

Author: gatorsmile <ga...@gmail.com>
Author: xiaoli <li...@gmail.com>
Author: Xiao Li <xi...@Xiaos-MacBook-Pro.local>

Closes #11283 from gatorsmile/groupingSetsToSQL.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/adce5ee7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/adce5ee7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/adce5ee7

Branch: refs/heads/master
Commit: adce5ee721c6a844ff21dfcd8515859458fe611d
Parents: f19228e
Author: gatorsmile <ga...@gmail.com>
Authored: Sat Mar 5 19:25:03 2016 +0800
Committer: Cheng Lian <li...@databricks.com>
Committed: Sat Mar 5 19:25:03 2016 +0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  14 +-
 .../sql/catalyst/expressions/grouping.scala     |   1 +
 .../org/apache/spark/sql/hive/SQLBuilder.scala  |  76 +++++++++-
 .../spark/sql/hive/LogicalPlanToSQLSuite.scala  | 143 +++++++++++++++++++
 4 files changed, 226 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/adce5ee7/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 92e724f..88924e2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -348,13 +348,13 @@ def grouping_id(*cols):
     grouping columns).
 
     >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
-    +-----+------------+--------+
-    | name|groupingid()|sum(age)|
-    +-----+------------+--------+
-    | null|           1|       7|
-    |Alice|           0|       2|
-    |  Bob|           0|       5|
-    +-----+------------+--------+
+    +-----+-------------+--------+
+    | name|grouping_id()|sum(age)|
+    +-----+-------------+--------+
+    | null|            1|       7|
+    |Alice|            0|       2|
+    |  Bob|            0|       5|
+    +-----+-------------+--------+
     """
     sc = SparkContext._active_spark_context
     jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))

http://git-wip-us.apache.org/repos/asf/spark/blob/adce5ee7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index a204060..437e417 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -63,4 +63,5 @@ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Une
   override def children: Seq[Expression] = groupByExprs
   override def dataType: DataType = IntegerType
   override def nullable: Boolean = false
+  override def prettyName: String = "grouping_id"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/adce5ee7/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 9a14ccf..8d411a9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types.{DataType, NullType}
+import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
 
 /**
  * A place holder for generated SQL for subquery expression.
@@ -118,6 +118,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     case p: Project =>
       projectToSQL(p, isDistinct = false)
 
+    case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
+      groupingSetToSQL(a, e, p)
+
     case p: Aggregate =>
       aggregateToSQL(p)
 
@@ -244,6 +247,77 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     )
   }
 
+  private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+    output1.size == output2.size &&
+      output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
+
+  private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
+    assert(a.child == e && e.child == p)
+    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
+      sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+  }
+
+  private def groupingSetToSQL(
+      agg: Aggregate,
+      expand: Expand,
+      project: Project): String = {
+    assert(agg.groupingExpressions.length > 1)
+
+    // The last column of Expand is always grouping ID
+    val gid = expand.output.last
+
+    val numOriginalOutput = project.child.output.length
+    // Assumption: Aggregate's groupingExpressions is composed of
+    // 1) the attributes of aliased group by expressions
+    // 2) gid, which is always the last one
+    val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
+    // Assumption: Project's projectList is composed of
+    // 1) the original output (Project's child.output),
+    // 2) the aliased group by expressions.
+    val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
+    val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
+
+    // a map from group by attributes to the original group by expressions.
+    val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+
+    val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
+      // Assumption: expand.projections is composed of
+      // 1) the original output (Project's child.output),
+      // 2) group by attributes(or null literal)
+      // 3) gid, which is always the last one in each project in Expand
+      project.drop(numOriginalOutput).dropRight(1).collect {
+        case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+      }
+    }
+    val groupingSetSQL =
+      "GROUPING SETS(" +
+        groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
+
+    val aggExprs = agg.aggregateExpressions.map { case expr =>
+      expr.transformDown {
+        // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
+        case ar: AttributeReference if ar == gid => GroupingID(Nil)
+        case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
+        case a @ Cast(BitwiseAnd(
+            ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
+            Literal(1, IntegerType)), ByteType) if ar == gid =>
+          // for converting an expression to its original SQL format grouping(col)
+          val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
+          groupByExprs.lift(idx).map(Grouping).getOrElse(a)
+      }
+    }
+
+    build(
+      "SELECT",
+      aggExprs.map(_.sql).mkString(", "),
+      if (agg.child == OneRowRelation) "" else "FROM",
+      toSQL(project.child),
+      "GROUP BY",
+      groupingSQL,
+      groupingSetSQL
+    )
+  }
+
   object Canonicalizer extends RuleExecutor[LogicalPlan] {
     override protected def batches: Seq[Batch] = Seq(
       Batch("Canonicalizer", FixedPoint(100),

http://git-wip-us.apache.org/repos/asf/spark/blob/adce5ee7/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index d708fcf..f457d43 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -218,6 +218,149 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkHiveQl("SELECT DISTINCT id FROM parquet_t0")
   }
 
+  test("rollup/cube #1") {
+    // Original logical plan:
+    //   Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
+    //             [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
+    //              (key#17L % cast(5 as bigint))#47L AS _c1#45L,
+    //              grouping__id#46 AS _c2#44]
+    //   +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
+    //              List(key#17L, value#18, null, 1)],
+    //             [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
+    //      +- Project [key#17L,
+    //                  value#18,
+    //                  (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
+    //         +- Subquery t1
+    //            +- Relation[key#17L,value#18] ParquetRelation
+    // Converted SQL:
+    //   SELECT count( 1) AS `cnt`,
+    //          (`t1`.`key` % CAST(5 AS BIGINT)),
+    //          grouping_id() AS `_c2`
+    //   FROM `default`.`t1`
+    //   GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
+    //   GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
+    checkHiveQl(
+      "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP")
+    checkHiveQl(
+      "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE")
+  }
+
+  test("rollup/cube #2") {
+    checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
+    checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE")
+  }
+
+  test("rollup/cube #3") {
+    checkHiveQl(
+      "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP")
+    checkHiveQl(
+      "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE")
+  }
+
+  test("rollup/cube #4") {
+    checkHiveQl(
+      s"""
+        |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
+        |GROUP BY key % 5, key - 5 WITH ROLLUP
+      """.stripMargin)
+    checkHiveQl(
+      s"""
+        |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
+        |GROUP BY key % 5, key - 5 WITH CUBE
+      """.stripMargin)
+  }
+
+  test("rollup/cube #5") {
+    checkHiveQl(
+      s"""
+        |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
+        |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5
+        |WITH ROLLUP
+      """.stripMargin)
+    checkHiveQl(
+      s"""
+        |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
+        |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
+        |WITH CUBE
+      """.stripMargin)
+  }
+
+  test("rollup/cube #6") {
+    checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
+    checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
+    checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b")
+    checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b")
+    checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP")
+    checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE")
+  }
+
+  test("rollup/cube #7") {
+    checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)")
+    checkHiveQl("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)")
+    checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)")
+  }
+
+  test("rollup/cube #8") {
+    // grouping_id() is part of another expression
+    checkHiveQl(
+      s"""
+         |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
+         |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
+         |WITH ROLLUP
+      """.stripMargin)
+    checkHiveQl(
+      s"""
+         |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
+         |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
+         |WITH CUBE
+      """.stripMargin)
+  }
+
+  test("rollup/cube #9") {
+    // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers
+    checkHiveQl(
+      s"""
+         |SELECT t.key - 5, cnt, SUM(cnt)
+         |FROM (SELECT x.key, COUNT(*) as cnt
+         |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
+         |GROUP BY cnt, t.key - 5
+         |WITH ROLLUP
+      """.stripMargin)
+    checkHiveQl(
+      s"""
+         |SELECT t.key - 5, cnt, SUM(cnt)
+         |FROM (SELECT x.key, COUNT(*) as cnt
+         |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
+         |GROUP BY cnt, t.key - 5
+         |WITH CUBE
+      """.stripMargin)
+  }
+
+  test("grouping sets #1") {
+    checkHiveQl(
+      s"""
+         |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3
+         |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
+         |GROUPING SETS (key % 5, key - 5)
+      """.stripMargin)
+  }
+
+  test("grouping sets #2") {
+    checkHiveQl(
+      "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b")
+    checkHiveQl(
+      "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b")
+    checkHiveQl(
+      "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b")
+    checkHiveQl(
+      "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b")
+    checkHiveQl(
+      s"""
+         |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b
+         |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b
+      """.stripMargin)
+  }
+
   test("cluster by") {
     checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id")
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org