You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2015/10/30 23:47:47 UTC

spark git commit: [SPARK-11423] remove MapPartitionsWithPreparationRDD

Repository: spark
Updated Branches:
  refs/heads/master bb5a2af03 -> 45029bfde


[SPARK-11423] remove MapPartitionsWithPreparationRDD

Since we no longer need to reserve a page before calling compute(), MapPartitionsWithPreparationRDD is not needed anymore.

This PR basically reverts #8543, #8511, #8038, and #8011.

Author: Davies Liu <da...@databricks.com>

Closes #9381 from davies/remove_prepare2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/45029bfd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/45029bfd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/45029bfd

Branch: refs/heads/master
Commit: 45029bfdea42eb8964f2ba697859687393d2a558
Parents: bb5a2af
Author: Davies Liu <da...@databricks.com>
Authored: Fri Oct 30 15:47:40 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Oct 30 15:47:40 2015 -0700

----------------------------------------------------------------------
 .../rdd/MapPartitionsWithPreparationRDD.scala   | 66 ----------------
 .../apache/spark/rdd/ZippedPartitionsRDD.scala  | 13 ----
 .../MapPartitionsWithPreparationRDDSuite.scala  | 66 ----------------
 project/MimaExcludes.scala                      |  6 +-
 .../UnsafeFixedWidthAggregationMap.java         |  9 +--
 .../execution/aggregate/TungstenAggregate.scala | 79 +++++++-------------
 .../aggregate/TungstenAggregationIterator.scala | 78 +++++++++----------
 .../org/apache/spark/sql/execution/sort.scala   | 28 ++-----
 8 files changed, 75 insertions(+), 270 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
deleted file mode 100644
index 417ff52..0000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-import org.apache.spark.{Partition, Partitioner, TaskContext}
-
-/**
- * An RDD that applies a user provided function to every partition of the parent RDD, and
- * additionally allows the user to prepare each partition before computing the parent partition.
- */
-private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
-    prev: RDD[T],
-    preparePartition: () => M,
-    executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean = false)
-  extends RDD[U](prev) {
-
-  override val partitioner: Option[Partitioner] = {
-    if (preservesPartitioning) firstParent[T].partitioner else None
-  }
-
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
-
-  // In certain join operations, prepare can be called on the same partition multiple times.
-  // In this case, we need to ensure that each call to compute gets a separate prepare argument.
-  private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M]
-
-  /**
-   * Prepare a partition for a single call to compute.
-   */
-  def prepare(): Unit = {
-    preparedArguments += preparePartition()
-  }
-
-  /**
-   * Prepare a partition before computing it from its parent.
-   */
-  override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-    val prepared =
-      if (preparedArguments.isEmpty) {
-        preparePartition()
-      } else {
-        preparedArguments.remove(0)
-      }
-    val parentIterator = firstParent[T].iterator(partition, context)
-    executePartition(context, partition.index, prepared, parentIterator)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 70bf04d..4333a67 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -73,16 +73,6 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
     super.clearDependencies()
     rdds = null
   }
-
-  /**
-   * Call the prepare method of every parent that has one.
-   * This is needed for reserving execution memory in advance.
-   */
-  protected def tryPrepareParents(): Unit = {
-    rdds.collect {
-      case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare()
-    }
-  }
 }
 
 private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
@@ -94,7 +84,6 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
-    tryPrepareParents()
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
   }
@@ -118,7 +107,6 @@ private[spark] class ZippedPartitionsRDD3
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
-    tryPrepareParents()
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context),
       rdd2.iterator(partitions(1), context),
@@ -146,7 +134,6 @@ private[spark] class ZippedPartitionsRDD4
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
-    tryPrepareParents()
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context),
       rdd2.iterator(partitions(1), context),

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
deleted file mode 100644
index e281e81..0000000
--- a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.collection.mutable
-
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}
-
-class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {
-
-  test("prepare called before parent partition is computed") {
-    sc = new SparkContext("local", "test")
-
-    // Have the parent partition push a number to the list
-    val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
-      TestObject.things.append(20)
-      iter
-    }
-
-    // Push a different number during the prepare phase
-    val preparePartition = () => { TestObject.things.append(10) }
-
-    // Push yet another number during the execution phase
-    val executePartition = (
-        taskContext: TaskContext,
-        partitionIndex: Int,
-        notUsed: Unit,
-        parentIterator: Iterator[Int]) => {
-      TestObject.things.append(30)
-      TestObject.things.iterator
-    }
-
-    // Verify that the numbers are pushed in the order expected
-    val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
-      parent, preparePartition, executePartition)
-    val result = rdd.collect()
-    assert(result === Array(10, 20, 30))
-
-    TestObject.things.clear()
-    // Zip two of these RDDs, both should be prepared before the parent is executed
-    val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
-      parent, preparePartition, executePartition)
-    val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect()
-    assert(result2 === Array(10, 10, 20, 30, 20, 30))
-  }
-
-}
-
-private object TestObject {
-  val things = new mutable.ListBuffer[Int]
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b5e661d..8282f7e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -107,7 +107,11 @@ object MimaExcludes {
           "org.apache.spark.sql.SQLContext.createSession")
       ) ++ Seq(
         ProblemFilters.exclude[MissingMethodProblem](
-          "org.apache.spark.SparkContext.preferredNodeLocationData_=")
+          "org.apache.spark.SparkContext.preferredNodeLocationData_="),
+        ProblemFilters.exclude[MissingClassProblem](
+          "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"),
+        ProblemFilters.exclude[MissingClassProblem](
+          "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$")
       )
     case v if v.startsWith("1.5") =>
       Seq(

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 889f970..d4b6d75 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution;
 
 import java.io.IOException;
 
-import com.google.common.annotations.VisibleForTesting;
-
 import org.apache.spark.SparkEnv;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -31,7 +30,6 @@ import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.memory.TaskMemoryManager;
 
 /**
  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -218,11 +216,6 @@ public final class UnsafeFixedWidthAggregationMap {
     return map.getPeakMemoryUsedBytes();
   }
 
-  @VisibleForTesting
-  public int getNumDataPages() {
-    return map.getNumDataPages();
-  }
-
   /**
    * Free the memory associated with this map. This is idempotent and can be called multiple times.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 0d3a4b3..1561691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.types.StructType
 
 case class TungstenAggregate(
@@ -84,59 +83,39 @@ case class TungstenAggregate(
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")
 
-    /**
-     * Set up the underlying unsafe data structures used before computing the parent partition.
-     * This makes sure our iterator is not starved by other operators in the same task.
-     */
-    def preparePartition(): TungstenAggregationIterator = {
-      new TungstenAggregationIterator(
-        groupingExpressions,
-        nonCompleteAggregateExpressions,
-        nonCompleteAggregateAttributes,
-        completeAggregateExpressions,
-        completeAggregateAttributes,
-        initialInputBufferOffset,
-        resultExpressions,
-        newMutableProjection,
-        child.output,
-        testFallbackStartsAt,
-        numInputRows,
-        numOutputRows,
-        dataSize,
-        spillSize)
-    }
+    child.execute().mapPartitions { iter =>
 
-    /** Compute a partition using the iterator already set up previously. */
-    def executePartition(
-        context: TaskContext,
-        partitionIndex: Int,
-        aggregationIterator: TungstenAggregationIterator,
-        parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
-      val hasInput = parentIterator.hasNext
-      if (!hasInput) {
-        // We're not using the underlying map, so we just can free it here
-        aggregationIterator.free()
-        if (groupingExpressions.isEmpty) {
+      val hasInput = iter.hasNext
+      if (!hasInput && groupingExpressions.nonEmpty) {
+        // This is a grouped aggregate and the input iterator is empty,
+        // so return an empty iterator.
+        Iterator.empty
+      } else {
+        val aggregationIterator =
+          new TungstenAggregationIterator(
+            groupingExpressions,
+            nonCompleteAggregateExpressions,
+            nonCompleteAggregateAttributes,
+            completeAggregateExpressions,
+            completeAggregateAttributes,
+            initialInputBufferOffset,
+            resultExpressions,
+            newMutableProjection,
+            child.output,
+            iter,
+            testFallbackStartsAt,
+            numInputRows,
+            numOutputRows,
+            dataSize,
+            spillSize)
+        if (!hasInput && groupingExpressions.isEmpty) {
           numOutputRows += 1
           Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
         } else {
-          // This is a grouped aggregate and the input iterator is empty,
-          // so return an empty iterator.
-          Iterator.empty
+          aggregationIterator
         }
-      } else {
-        aggregationIterator.start(parentIterator)
-        aggregationIterator
       }
     }
-
-    // Note: we need to set up the iterator in each partition before computing the
-    // parent partition, so we cannot simply use `mapPartitions` here (SPARK-9747).
-    val resultRdd = {
-      new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, TungstenAggregationIterator](
-        child.execute(), preparePartition, executePartition, preservesPartitioning = true)
-    }
-    resultRdd.asInstanceOf[RDD[InternalRow]]
   }
 
   override def simpleString: String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index fb2fc98..713a4db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -74,6 +74,8 @@ import org.apache.spark.sql.types.StructType
  *   the function used to create mutable projections.
  * @param originalInputAttributes
  *   attributes of representing input rows from `inputIter`.
+ * @param inputIter
+ *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
@@ -85,6 +87,7 @@ class TungstenAggregationIterator(
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
     originalInputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow],
     testFallbackStartsAt: Option[Int],
     numInputRows: LongSQLMetric,
     numOutputRows: LongSQLMetric,
@@ -92,9 +95,6 @@ class TungstenAggregationIterator(
     spillSize: LongSQLMetric)
   extends Iterator[UnsafeRow] with Logging {
 
-  // The parent partition iterator, to be initialized later in `start`
-  private[this] var inputIter: Iterator[InternalRow] = null
-
   ///////////////////////////////////////////////////////////////////////////
   // Part 1: Initializing aggregate functions.
   ///////////////////////////////////////////////////////////////////////////
@@ -486,15 +486,11 @@ class TungstenAggregationIterator(
     false // disable tracking of performance metrics
   )
 
-  // Exposed for testing
-  private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
-
   // The function used to read and process input rows. When processing input rows,
   // it first uses hash-based aggregation by putting groups and their buffers in
   // hashMap. If we could not allocate more memory for the map, we switch to
   // sort-based aggregation (by calling switchToSortBasedAggregation).
   private def processInputs(): Unit = {
-    assert(inputIter != null, "attempted to process input when iterator was null")
     if (groupingExpressions.isEmpty) {
       // If there is no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
@@ -526,7 +522,6 @@ class TungstenAggregationIterator(
   // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
   // been processed.
   private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
-    assert(inputIter != null, "attempted to process input when iterator was null")
     var i = 0
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
@@ -567,15 +562,11 @@ class TungstenAggregationIterator(
    * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
    */
   private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
-    assert(inputIter != null, "attempted to process input when iterator was null")
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
     externalSorter = hashMap.destructAndCreateExternalSorter()
 
-    // Step 2: Free the memory used by the map.
-    hashMap.free()
-
-    // Step 3: If we have aggregate function with mode Partial or Complete,
+    // Step 2: If we have aggregate function with mode Partial or Complete,
     // we need to process input rows to get aggregation buffer.
     // So, later in the sort-based aggregation iterator, we can do merge.
     // If aggregate functions are with mode Final and PartialMerge,
@@ -770,31 +761,27 @@ class TungstenAggregationIterator(
 
   /**
    * Start processing input rows.
-   * Only after this method is called will this iterator be non-empty.
    */
-  def start(parentIter: Iterator[InternalRow]): Unit = {
-    inputIter = parentIter
-    testFallbackStartsAt match {
-      case None =>
-        processInputs()
-      case Some(fallbackStartsAt) =>
-        // This is the testing path. processInputsWithControlledFallback is same as processInputs
-        // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
-        // have been processed.
-        processInputsWithControlledFallback(fallbackStartsAt)
-    }
+  testFallbackStartsAt match {
+    case None =>
+      processInputs()
+    case Some(fallbackStartsAt) =>
+      // This is the testing path. processInputsWithControlledFallback is same as processInputs
+      // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows
+      // have been processed.
+      processInputsWithControlledFallback(fallbackStartsAt)
+  }
 
-    // If we did not switch to sort-based aggregation in processInputs,
-    // we pre-load the first key-value pair from the map (to make hasNext idempotent).
-    if (!sortBased) {
-      // First, set aggregationBufferMapIterator.
-      aggregationBufferMapIterator = hashMap.iterator()
-      // Pre-load the first key-value pair from the aggregationBufferMapIterator.
-      mapIteratorHasNext = aggregationBufferMapIterator.next()
-      // If the map is empty, we just free it.
-      if (!mapIteratorHasNext) {
-        hashMap.free()
-      }
+  // If we did not switch to sort-based aggregation in processInputs,
+  // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+  if (!sortBased) {
+    // First, set aggregationBufferMapIterator.
+    aggregationBufferMapIterator = hashMap.iterator()
+    // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+    mapIteratorHasNext = aggregationBufferMapIterator.next()
+    // If the map is empty, we just free it.
+    if (!mapIteratorHasNext) {
+      hashMap.free()
     }
   }
 
@@ -868,13 +855,16 @@ class TungstenAggregationIterator(
    * Generate a output row when there is no input and there is no grouping expression.
    */
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
-    assert(groupingExpressions.isEmpty)
-    assert(inputIter == null)
-    generateOutput(UnsafeRow.createFromByteArray(0, 0), initialAggregationBuffer)
-  }
-
-  /** Free memory used in the underlying map. */
-  def free(): Unit = {
-    hashMap.free()
+    if (groupingExpressions.isEmpty) {
+      sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+      // We create a output row and copy it. So, we can free the map.
+      val resultCopy =
+        generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
+      hashMap.free()
+      resultCopy
+    } else {
+      throw new IllegalStateException(
+        "This method should not be called when groupingExpressions is not empty.")
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45029bfd/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index dd92dda..1a3832a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext}
+import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines various sort operators.
@@ -77,6 +77,7 @@ case class Sort(
  * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
  *                           spill every `frequency` records.
  */
+
 case class TungstenSort(
     sortOrder: Seq[SortOrder],
     global: Boolean,
@@ -106,11 +107,7 @@ case class TungstenSort(
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")
 
-    /**
-     * Set up the sorter in each partition before computing the parent partition.
-     * This makes sure our sorter is not starved by other sorters used in the same task.
-     */
-    def preparePartition(): UnsafeExternalRowSorter = {
+    child.execute().mapPartitions { iter =>
       val ordering = newOrdering(sortOrder, childOutput)
 
       // The comparator for comparing prefix
@@ -131,33 +128,20 @@ case class TungstenSort(
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
-      sorter
-    }
 
-    /** Compute a partition using the sorter already set up previously. */
-    def executePartition(
-        taskContext: TaskContext,
-        partitionIndex: Int,
-        sorter: UnsafeExternalRowSorter,
-        parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       // Remember spill data size of this task before execute this operator so that we can
       // figure out how many bytes we spilled for this operator.
       val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
 
-      val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
+      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
 
       dataSize += sorter.getPeakMemoryUsage
       spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
 
-      taskContext.internalMetricsToAccumulators(
+      TaskContext.get().internalMetricsToAccumulators(
         InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
       sortedIterator
     }
-
-    // Note: we need to set up the external sorter in each partition before computing
-    // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
-    new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
-      child.execute(), preparePartition, executePartition, preservesPartitioning = true)
   }
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org