You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/06/23 19:52:25 UTC

spark git commit: [SPARK-7235] [SQL] Refactor the grouping sets

Repository: spark
Updated Branches:
  refs/heads/master 4f7fbefb8 -> 7b1450b66


[SPARK-7235] [SQL] Refactor the grouping sets

The logical plan `Expand` takes `output` as a constructor argument, which breaks the attribute reference chain. We need to refactor the code, as well as the column pruning.

Author: Cheng Hao <ha...@intel.com>

Closes #5780 from chenghao-intel/expand and squashes the following commits:

76e4aa4 [Cheng Hao] revert the change for case insensitivity
7c10a83 [Cheng Hao] refactor the grouping sets


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7b1450b6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7b1450b6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7b1450b6

Branch: refs/heads/master
Commit: 7b1450b666f88452e7fe969a6d59e8b24842ea39
Parents: 4f7fbef
Author: Cheng Hao <ha...@intel.com>
Authored: Tue Jun 23 10:52:17 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Jun 23 10:52:17 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 55 +++----------
 .../catalyst/expressions/namedExpressions.scala |  2 +-
 .../sql/catalyst/optimizer/Optimizer.scala      |  4 +
 .../catalyst/plans/logical/basicOperators.scala | 84 +++++++++++++++-----
 .../spark/sql/execution/SparkStrategies.scala   |  4 +-
 5 files changed, 78 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7b1450b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6311784..0a3f5a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -192,49 +192,17 @@ class Analyzer(
       Seq.tabulate(1 << c.groupByExprs.length)(i => i)
     }
 
-    /**
-     * Create an array of Projections for the child projection, and replace the projections'
-     * expressions which equal GroupBy expressions with Literal(null), if those expressions
-     * are not set for this grouping set (according to the bit mask).
-     */
-    private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
-      val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
-
-      g.bitmasks.foreach { bitmask =>
-        // get the non selected grouping attributes according to the bit mask
-        val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
-        var bit = g.groupByExprs.length - 1
-        while (bit >= 0) {
-          if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
-          bit -= 1
-        }
-
-        val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
-          case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
-            // if the input attribute in the Invalid Grouping Expression set of for this group
-            // replace it with constant null
-            Literal.create(null, expr.dataType)
-          case x if x == g.gid =>
-            // replace the groupingId with concrete value (the bit mask)
-            Literal.create(bitmask, IntegerType)
-        })
-
-        result += substitution
-      }
-
-      result.toSeq
-    }
-
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case a: Cube if a.resolved =>
-        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
-      case a: Rollup if a.resolved =>
-        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations, a.gid)
-      case x: GroupingSets if x.resolved =>
+      case a: Cube =>
+        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
+      case a: Rollup =>
+        GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
+      case x: GroupingSets =>
+        val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
         Aggregate(
-          x.groupByExprs :+ x.gid,
+          x.groupByExprs :+ VirtualColumn.groupingIdAttribute,
           x.aggregations,
-          Expand(expand(x), x.child.output :+ x.gid, x.child))
+          Expand(x.bitmasks, x.groupByExprs, gid, x.child))
     }
   }
 
@@ -368,12 +336,7 @@ class Analyzer(
 
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
-        q transformExpressionsUp {
-          case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
-            resolver(nameParts(0), VirtualColumn.groupingIdName) &&
-            q.isInstanceOf[GroupingAnalytics] =>
-            // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
-            q.asInstanceOf[GroupingAnalytics].gid
+        q transformExpressionsUp  {
           case u @ UnresolvedAttribute(nameParts) =>
             // Leave unchanged if resolution fails.  Hopefully will be resolved next round.
             val result =

http://git-wip-us.apache.org/repos/asf/spark/blob/7b1450b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 58dbeaf..9cacdce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -262,5 +262,5 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E
 
 object VirtualColumn {
   val groupingIdName: String = "grouping__id"
-  def newGroupingId: AttributeReference = AttributeReference(groupingIdName, IntegerType, false)()
+  val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7b1450b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9132a78..98b4476 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -121,6 +121,10 @@ object UnionPushdown extends Rule[LogicalPlan] {
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child))
+      if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references)))
+
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = Project(a.references.toSeq, child))

http://git-wip-us.apache.org/repos/asf/spark/blob/7b1450b6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 7814e51..fae3398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.OpenHashSet
 
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -228,24 +229,76 @@ case class Window(
/**
  * Apply the all of the GroupExpressions to every input row, hence we will get
  * multiple output rows for a input row.
- * @param projections The group of expressions, all of the group expressions should
- *                    output the same schema specified by the parameter `output`
- * @param output      The output Schema
+ * @param bitmasks One bitmask per grouping set; one projection is built per bitmask
+ * @param groupByExprs The group-by expressions; bit i of a bitmask refers to groupByExprs(i)
  * @param child       Child operator
  */
 case class Expand(
-    projections: Seq[Seq[Expression]],
-    output: Seq[Attribute],
+    bitmasks: Seq[Int],
+    groupByExprs: Seq[Expression],
+    gid: Attribute,
     child: LogicalPlan) extends UnaryNode {
   override def statistics: Statistics = {
     val sizeInBytes = child.statistics.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
+
+  val projections: Seq[Seq[Expression]] = expand()
+
+  /**
+   * Build the set of group-by expressions that are NOT selected by the given bitmask.
+   * @param bitmask grouping-set bitmask; bit i = 1 means exprs(i) is selected
+   * @param exprs the group-by expressions, where exprs(i) corresponds to bit i
+   * @return the expressions in `exprs` whose bit in `bitmask` is 0 (i.e. not selected)
+   */
+  private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
+  : OpenHashSet[Expression] = {
+    val set = new OpenHashSet[Expression](2)
+
+    var bit = exprs.length - 1
+    while (bit >= 0) {
+      if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
+      bit -= 1
+    }
+
+    set
+  }
+
+  /**
+   * Build one projection per bitmask over `child.output :+ gid`. In each projection,
+   * sub-expressions equal to a group-by expression not selected by that bitmask become a
+   * null Literal, and `gid` becomes the bitmask value as an IntegerType Literal.
+   */
+  private[this] def expand(): Seq[Seq[Expression]] = {
+    val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
+
+    bitmasks.foreach { bitmask =>
+      // collect the group-by expressions that this bitmask does NOT select
+      val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
+
+      val substitution = (child.output :+ gid).map(expr => expr transformDown {
+        case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+          // x is a group-by expression not selected by this grouping set: emit null.
+          // NOTE(review): set lookup is presumably equals-based, not semanticEquals
+          Literal.create(null, expr.dataType)
+        case x if x == gid =>
+          // replace the grouping id with its concrete value for this set (the bitmask)
+          Literal.create(bitmask, IntegerType)
+      })
+
+      result += substitution
+    }
+
+    result.toSeq
+  }
+
+  override def output: Seq[Attribute] = {
+    child.output :+ gid
+  }
 }
 
 trait GroupingAnalytics extends UnaryNode {
   self: Product =>
-  def gid: AttributeReference
   def groupByExprs: Seq[Expression]
   def aggregations: Seq[NamedExpression]
 
@@ -266,17 +319,12 @@ trait GroupingAnalytics extends UnaryNode {
  * @param child        Child operator
  * @param aggregations The Aggregation expressions, those non selected group by expressions
  *                     will be considered as constant null if it appears in the expressions
- * @param gid          The attribute represents the virtual column GROUPING__ID, and it's also
- *                     the bitmask indicates the selected GroupBy Expressions for each
- *                     aggregating output row.
- *                     The associated output will be one of the value in `bitmasks`
  */
 case class GroupingSets(
     bitmasks: Seq[Int],
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
-    aggregations: Seq[NamedExpression],
-    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
 
   def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
     this.copy(aggregations = aggs)
@@ -290,15 +338,11 @@ case class GroupingSets(
  * @param child        Child operator
  * @param aggregations The Aggregation expressions, those non selected group by expressions
  *                     will be considered as constant null if it appears in the expressions
- * @param gid          The attribute represents the virtual column GROUPING__ID, and it's also
- *                     the bitmask indicates the selected GroupBy Expressions for each
- *                     aggregating output row.
  */
 case class Cube(
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
-    aggregations: Seq[NamedExpression],
-    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
 
   def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
     this.copy(aggregations = aggs)
@@ -313,15 +357,11 @@ case class Cube(
  * @param child        Child operator
  * @param aggregations The Aggregation expressions, those non selected group by expressions
  *                     will be considered as constant null if it appears in the expressions
- * @param gid          The attribute represents the virtual column GROUPING__ID, and it's also
- *                     the bitmask indicates the selected GroupBy Expressions for each
- *                     aggregating output row.
  */
 case class Rollup(
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
-    aggregations: Seq[NamedExpression],
-    gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+    aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
 
   def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
     this.copy(aggregations = aggs)

http://git-wip-us.apache.org/repos/asf/spark/blob/7b1450b6/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5c420eb..1ff1cc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -308,8 +308,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Project(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
-      case logical.Expand(projections, output, child) =>
-        execution.Expand(projections, output, planLater(child)) :: Nil
+      case e @ logical.Expand(_, _, _, child) =>
+        execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
         execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
       case logical.Window(projectList, windowExpressions, spec, child) =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org