Posted to commits@spark.apache.org by da...@apache.org on 2016/01/04 23:27:02 UTC

spark git commit: [SPARK-12541] [SQL] support cube/rollup as functions

Repository: spark
Updated Branches:
  refs/heads/master 93ef9b6a2 -> d084a2de3


[SPARK-12541] [SQL] support cube/rollup as functions

This PR enables cube/rollup as functions, so they can be used like this:
```
select a, b, sum(c) from t group by rollup(a, b)
```
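
Both functions are resolved by the analyzer into the same `GroupingSets` plan that the dedicated `Cube`/`Rollup` logical operators used to produce, so `cube` works the same way. A minimal sketch of exercising both from Scala (a `sqlContext` and a registered table `t` with columns `a`, `b`, `c` are assumed, purely for illustration):
```
// Hedged sketch: the new function-style syntax, called from Scala.
// `sqlContext` and table `t` (columns a, b, c) are illustrative.
val rolledUp = sqlContext.sql("select a, b, sum(c) from t group by rollup(a, b)")
val cubed = sqlContext.sql("select a, b, sum(c) from t group by cube(a, b)")
```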

Author: Davies Liu <da...@databricks.com>

Closes #10522 from davies/rollup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d084a2de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d084a2de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d084a2de

Branch: refs/heads/master
Commit: d084a2de3271fd8b0d29ee67e1e214e8b9d96d6d
Parents: 93ef9b6
Author: Davies Liu <da...@databricks.com>
Authored: Mon Jan 4 14:26:56 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Jan 4 14:26:56 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 10 ++---
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  2 +-
 .../catalyst/analysis/FunctionRegistry.scala    |  4 ++
 .../sql/catalyst/expressions/grouping.scala     | 43 ++++++++++++++++++++
 .../catalyst/plans/logical/basicOperators.scala | 37 -----------------
 .../org/apache/spark/sql/GroupedData.scala      |  6 +--
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 29 +++++++++++++
 .../org/apache/spark/sql/hive/HiveQl.scala      |  4 +-
 8 files changed, 87 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c396546..06efcd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf}
+import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
 import org.apache.spark.sql.types._
 
 /**
@@ -208,10 +208,10 @@ class Analyzer(
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case a if !a.childrenResolved => a // be sure all of the children are resolved.
-      case a: Cube =>
-        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
-      case a: Rollup =>
-        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
+      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+        GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
+      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+        GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
       case x: GroupingSets =>
         val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a1be147..2a2e0d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 12c24cc..57d1a11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -285,6 +285,10 @@ object FunctionRegistry {
     expression[InputFileName]("input_file_name"),
     expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
 
+    // grouping sets
+    expression[Cube]("cube"),
+    expression[Rollup]("rollup"),
+
     // window functions
     expression[Lead]("lead"),
     expression[Lag]("lag"),

http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
new file mode 100644
index 0000000..2997ee8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+
+/**
+  * A placeholder expression for cube/rollup, which will be replaced by analyzer
+  */
+trait GroupingSet extends Expression with CodegenFallback {
+
+  def groupByExprs: Seq[Expression]
+  override def children: Seq[Expression] = groupByExprs
+
+  // this should be replaced first
+  override lazy val resolved: Boolean = false
+
+  override def dataType: DataType = throw new UnsupportedOperationException
+  override def foldable: Boolean = false
+  override def nullable: Boolean = true
+  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+}
+
+case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}
+
+case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}

http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5f34d4a..986062e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -397,43 +397,6 @@ case class GroupingSets(
     this.copy(aggregations = aggs)
 }
 
-/**
- * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
- * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
- *
- * @param groupByExprs The Group By expressions candidates.
- * @param child        Child operator
- * @param aggregations The Aggregation expressions, those non selected group by expressions
- *                     will be considered as constant null if it appears in the expressions
- */
-case class Cube(
-    groupByExprs: Seq[Expression],
-    child: LogicalPlan,
-    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
-
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
-    this.copy(aggregations = aggs)
-}
-
-/**
- * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
- * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
- *
- * @param groupByExprs The Group By expressions candidates, take effective only if the
- *                     associated bit in the bitmask set to 1.
- * @param child        Child operator
- * @param aggregations The Aggregation expressions, those non selected group by expressions
- *                     will be considered as constant null if it appears in the expressions
- */
-case class Rollup(
-    groupByExprs: Seq[Expression],
-    child: LogicalPlan,
-    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
-
-  def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
-    this.copy(aggregations = aggs)
-}
-
 case class Pivot(
     groupByExprs: Seq[NamedExpression],
     pivotColumn: Expression,

http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 13341a8..2aa82f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Aggregate}
 import org.apache.spark.sql.types.NumericType
 
 
@@ -58,10 +58,10 @@ class GroupedData protected[sql](
           df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
       case GroupedData.RollupType =>
         DataFrame(
-          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
+          df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan))
       case GroupedData.CubeType =>
         DataFrame(
-          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+          df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan))
       case GroupedData.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         DataFrame(

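From the caller's side the DataFrame API is unchanged: `rollup`/`cube` on `GroupedData` now wrap the grouping expressions in the new placeholder expressions inside an ordinary `Aggregate` instead of building dedicated logical nodes. A hedged usage sketch (the DataFrame `df` and its columns are illustrative, mirroring the test data below):
```
// Hedged sketch: the DataFrame-side path through GroupedData, which now
// builds Aggregate(Seq(Rollup(...)), ...) under the hood. `df` is assumed
// to be a DataFrame with columns course, year, earnings.
import org.apache.spark.sql.functions.sum

val byRollup = df.rollup("course", "year").agg(sum("earnings"))
val byCube = df.cube("course", "year").agg(sum("earnings"))
```
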
http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index bb82b56..115b617 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2028,4 +2028,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Row(false) :: Row(true) :: Nil)
   }
 
+  test("rollup") {
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" +
+        " order by course, year"),
+      Row(null, null, 113000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) :: Nil
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"),
+      Row("Java", 2012, 20000.0) ::
+        Row("Java", 2013, 30000.0) ::
+        Row("Java", null, 50000.0) ::
+        Row("dotNET", 2012, 15000.0) ::
+        Row("dotNET", 2013, 48000.0) ::
+        Row("dotNET", null, 63000.0) ::
+        Row(null, 2012, 35000.0) ::
+        Row(null, 2013, 78000.0) ::
+        Row(null, null, 113000.0) :: Nil
+    )
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d084a2de/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index b1d841d..cbfe09b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1121,12 +1121,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
             }),
             rollupGroupByClause.map(e => e match {
               case Token("TOK_ROLLUP_GROUPBY", children) =>
-                Rollup(children.map(nodeToExpr), withLateralView, selectExpressions)
+                Aggregate(Seq(Rollup(children.map(nodeToExpr))), selectExpressions, withLateralView)
               case _ => sys.error("Expect WITH ROLLUP")
             }),
             cubeGroupByClause.map(e => e match {
               case Token("TOK_CUBE_GROUPBY", children) =>
-                Cube(children.map(nodeToExpr), withLateralView, selectExpressions)
+                Aggregate(Seq(Cube(children.map(nodeToExpr))), selectExpressions, withLateralView)
               case _ => sys.error("Expect WITH CUBE")
             }),
             Some(Project(selectExpressions, withLateralView))).flatten.head

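The HiveQL path keeps the pre-existing `WITH ROLLUP`/`WITH CUBE` syntax but now routes it through the same placeholder expressions, so both spellings should analyze to the same plan. A sketch under the same assumptions as above:
```
// Hedged sketch: Hive-style and function-style rollup should now be
// analyzed into the same GroupingSets plan; table `t` is illustrative.
sqlContext.sql("select a, b, sum(c) from t group by a, b with rollup")
sqlContext.sql("select a, b, sum(c) from t group by rollup(a, b)")
```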
