You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/02/03 04:47:49 UTC

spark git commit: [SPARK-12951] [SQL] support spilling in generated aggregate

Repository: spark
Updated Branches:
  refs/heads/master ff71261b6 -> 99a6e3c1e


[SPARK-12951] [SQL] support spilling in generated aggregate

This PR adds spilling support for generated TungstenAggregate.

If spilling happens, it is acceptable to fall back to the iterator-based sort-merge aggregate (not generated).

The changes will be covered by TungstenAggregationQueryWithControlledFallbackSuite

Author: Davies Liu <da...@databricks.com>

Closes #10998 from davies/gen_spilling.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/99a6e3c1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/99a6e3c1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/99a6e3c1

Branch: refs/heads/master
Commit: 99a6e3c1e8d580ce1cc497bd9362eaf16c597f77
Parents: ff71261
Author: Davies Liu <da...@databricks.com>
Authored: Tue Feb 2 19:47:44 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Feb 2 19:47:44 2016 -0800

----------------------------------------------------------------------
 .../execution/aggregate/TungstenAggregate.scala | 172 +++++++++++++++----
 1 file changed, 142 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/99a6e3c1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a8a81d6..f61db85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.KVIterator
 
 case class TungstenAggregate(
@@ -258,6 +258,7 @@ case class TungstenAggregate(
 
   // The name for HashMap
   private var hashMapTerm: String = _
+  private var sorterTerm: String = _
 
   /**
     * This is called by generated Java class, should be public.
@@ -286,39 +287,98 @@ case class TungstenAggregate(
     GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
   }
 
-
   /**
-    * Update peak execution memory, called in generated Java class.
+    * Called by generated Java class to finish the aggregate and return a KVIterator.
     */
-  def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
+  def finishAggregate(
+      hashMap: UnsafeFixedWidthAggregationMap,
+      sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
+
+    // update peak execution memory
     val mapMemory = hashMap.getPeakMemoryUsedBytes
+    val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+    val peakMemory = Math.max(mapMemory, sorterMemory)
     val metrics = TaskContext.get().taskMetrics()
-    metrics.incPeakExecutionMemory(mapMemory)
-  }
+    metrics.incPeakExecutionMemory(peakMemory)
 
-  private def doProduceWithKeys(ctx: CodegenContext): String = {
-    val initAgg = ctx.freshName("initAgg")
-    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+    if (sorter == null) {
+      // not spilled
+      return hashMap.iterator()
+    }
 
-    // create hashMap
-    val thisPlan = ctx.addReferenceObj("plan", this)
-    hashMapTerm = ctx.freshName("hashMap")
-    val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
-    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+    // merge the final hashMap into sorter
+    sorter.merge(hashMap.destructAndCreateExternalSorter())
+    hashMap.free()
+    val sortedIter = sorter.sortedIterator()
+
+    // Create a KVIterator based on the sorted iterator.
+    new KVIterator[UnsafeRow, UnsafeRow] {
+
+      // Create a MutableProjection to merge the rows of same key together
+      val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
+      val mergeProjection = newMutableProjection(
+        mergeExpr,
+        bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+        subexpressionEliminationEnabled)()
+      val joinedRow = new JoinedRow()
+
+      var currentKey: UnsafeRow = null
+      var currentRow: UnsafeRow = null
+      var nextKey: UnsafeRow = if (sortedIter.next()) {
+        sortedIter.getKey
+      } else {
+        null
+      }
 
-    // Create a name for iterator from HashMap
-    val iterTerm = ctx.freshName("mapIter")
-    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
+      override def next(): Boolean = {
+        if (nextKey != null) {
+          currentKey = nextKey.copy()
+          currentRow = sortedIter.getValue.copy()
+          nextKey = null
+          // use the first row as aggregate buffer
+          mergeProjection.target(currentRow)
+
+          // merge the following rows with same key together
+          var findNextGroup = false
+          while (!findNextGroup && sortedIter.next()) {
+            val key = sortedIter.getKey
+            if (currentKey.equals(key)) {
+              mergeProjection(joinedRow(currentRow, sortedIter.getValue))
+            } else {
+              // We find a new group.
+              findNextGroup = true
+              nextKey = key
+            }
+          }
+
+          true
+        } else {
+          false
+        }
+      }
 
-    // generate code for output
-    val keyTerm = ctx.freshName("aggKey")
-    val bufferTerm = ctx.freshName("aggBuffer")
-    val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
+      override def getKey: UnsafeRow = currentKey
+      override def getValue: UnsafeRow = currentRow
+      override def close(): Unit = {
+        sortedIter.close()
+      }
+    }
+  }
+
+  /**
+    * Generate the code for output.
+    */
+  private def generateResultCode(
+      ctx: CodegenContext,
+      keyTerm: String,
+      bufferTerm: String,
+      plan: String): String = {
+    if (modes.contains(Final) || modes.contains(Complete)) {
       // generate output using resultExpressions
       ctx.currentVars = null
       ctx.INPUT_ROW = keyTerm
       val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
-          BoundReference(i, e.dataType, e.nullable).gen(ctx)
+        BoundReference(i, e.dataType, e.nullable).gen(ctx)
       }
       ctx.INPUT_ROW = bufferTerm
       val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
@@ -348,7 +408,7 @@ case class TungstenAggregate(
       // This should be the last operator in a stage, we should output UnsafeRow directly
       val joinerTerm = ctx.freshName("unsafeRowJoiner")
       ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
-        s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
+        s"$joinerTerm = $plan.createUnsafeJoiner();")
       val resultRow = ctx.freshName("resultRow")
       s"""
        UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
@@ -367,6 +427,23 @@ case class TungstenAggregate(
        ${consume(ctx, eval)}
        """
     }
+  }
+
+  private def doProduceWithKeys(ctx: CodegenContext): String = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // create hashMap
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    hashMapTerm = ctx.freshName("hashMap")
+    val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+    sorterTerm = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
+
+    // Create a name for iterator from HashMap
+    val iterTerm = ctx.freshName("mapIter")
+    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
 
     val doAgg = ctx.freshName("doAggregateWithKeys")
     ctx.addNewFunction(doAgg,
@@ -374,10 +451,15 @@ case class TungstenAggregate(
         private void $doAgg() throws java.io.IOException {
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
 
-          $iterTerm = $hashMapTerm.iterator();
+          $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
         }
        """)
 
+    // generate code for output
+    val keyTerm = ctx.freshName("aggKey")
+    val bufferTerm = ctx.freshName("aggBuffer")
+    val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+
     s"""
      if (!$initAgg) {
        $initAgg = true;
@@ -391,8 +473,10 @@ case class TungstenAggregate(
        $outputCode
      }
 
-     $thisPlan.updatePeakMemory($hashMapTerm);
-     $hashMapTerm.free();
+     $iterTerm.close();
+     if ($sorterTerm == null) {
+       $hashMapTerm.free();
+     }
      """
   }
 
@@ -425,14 +509,42 @@ case class TungstenAggregate(
       ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
     }
 
+    val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
+      val countTerm = ctx.freshName("fallbackCounter")
+      ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+      (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
+    } else {
+      ("true", "", "")
+    }
+
+    // We try to do hash map based in-memory aggregation first. If there is not enough memory (the
+    // hash map will return null for new key), we spill the hash map to disk to free memory, then
+    // continue to do in-memory aggregation and spilling until all the rows had been processed.
+    // Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
     s"""
      // generate grouping key
      ${keyCode.code}
-     UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+     UnsafeRow $buffer = null;
+     if ($checkFallback) {
+       // try to get the buffer from hash map
+       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+     }
      if ($buffer == null) {
-       // failed to allocate the first page
-       throw new OutOfMemoryError("No enough memory for aggregation");
+       if ($sorterTerm == null) {
+         $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+       } else {
+         $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+       }
+       $resetCoulter
+       // the hash map had be spilled, it should have enough memory now,
+       // try  to allocate buffer again.
+       $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+       if ($buffer == null) {
+         // failed to allocate the first page
+         throw new OutOfMemoryError("No enough memory for aggregation");
+       }
      }
+     $incCounter
 
      // evaluate aggregate function
      ${evals.map(_.code).mkString("\n")}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org