Posted to commits@spark.apache.org by da...@apache.org on 2015/10/30 07:39:09 UTC

[1/3] spark git commit: [SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

Repository: spark
Updated Branches:
  refs/heads/master d89be0bf8 -> 56419cf11


http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index dbf4863..a386236 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -24,7 +24,7 @@ import scala.util.{Try, Random}
 import org.scalatest.Matchers
 
 import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite}
-import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
 import org.apache.spark.sql.test.SharedSQLContext
@@ -48,7 +48,7 @@ class UnsafeFixedWidthAggregationMapSuite
   private def emptyAggregationBuffer: InternalRow = InternalRow(0)
   private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
 
-  private var memoryManager: GrantEverythingMemoryManager = null
+  private var memoryManager: TestMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
 
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
@@ -62,7 +62,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
     test(name) {
       val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
-      memoryManager = new GrantEverythingMemoryManager(conf)
+      memoryManager = new TestMemoryManager(conf)
       taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
 
       TaskContext.setTaskContext(new TaskContextImpl(
@@ -193,10 +193,6 @@ class UnsafeFixedWidthAggregationMapSuite
     // Convert the map into a sorter
     val sorter = map.destructAndCreateExternalSorter()
 
-    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
-      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
-    }
-
     // Add more keys to the sorter and make sure the results come out sorted.
     val additionalKeys = randomStrings(1024)
     val keyConverter = UnsafeProjection.create(groupKeySchema)
@@ -208,7 +204,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
@@ -251,7 +247,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
@@ -294,16 +290,12 @@ class UnsafeFixedWidthAggregationMapSuite
     // Convert the map into a sorter. Right now, it contains one record.
     val sorter = map.destructAndCreateExternalSorter()
 
-    withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
-      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
-    }
-
     // Add more keys to the sorter and make sure the results come out sorted.
     (1 to 4096).foreach { i =>
       sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
 
       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }
@@ -342,7 +334,7 @@ class UnsafeFixedWidthAggregationMapSuite
       buf.setInt(0, str.length)
     }
     // Simulate running out of space
-    memoryManager.markExecutionAsOutOfMemory()
+    memoryManager.limit(0)
     val str = rand.nextString(1024)
     val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
     assert(buf == null)

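The two failure-injection calls used above do different things: markExecutionAsOutOfMemoryOnce() makes exactly one allocation request fail (forcing a single spill, after which allocation succeeds again), while limit(0) caps the task's memory budget so that getAggregationBuffer() genuinely cannot acquire memory and returns null. A standalone toy sketch of those semantics (illustrative only; Spark's actual TestMemoryManager is a Scala class added elsewhere in this commit):

```java
// Toy model of the two failure modes exercised by the tests above.
class ToyTestMemoryManager {
  private boolean oomOnce = false;          // set by markExecutionAsOutOfMemoryOnce()
  private long available = Long.MAX_VALUE;  // set by limit(n)

  void markExecutionAsOutOfMemoryOnce() { oomOnce = true; }
  void limit(long bytes) { available = bytes; }

  long acquireExecutionMemory(long size) {
    if (oomOnce) {          // fail exactly once: the caller spills, then retries
      oomOnce = false;
      return 0L;
    }
    long granted = Math.min(size, available);
    available -= granted;   // with limit(0), nothing is ever granted
    return granted;
  }
}
```
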
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 13dc175..7b80963 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.util.Random
 
 import org.apache.spark._
-import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
@@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       pageSize: Long,
       spill: Boolean): Unit = {
     val memoryManager =
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
@@ -128,7 +128,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
       // 1% chance we will spill
       if (rand.nextDouble() < 0.01 && spill) {
-        memoryManager.markExecutionAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemoryOnce()
         sorter.closeCurrentPage()
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
deleted file mode 100644
index 475037b..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark._
-import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.SharedSQLContext
-
-class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
-
-  test("memory acquired on construction") {
-    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
-    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
-    TaskContext.setTaskContext(taskContext)
-
-    // Assert that a page is allocated before processing starts
-    var iter: TungstenAggregationIterator = null
-    try {
-      val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
-        () => new InterpretedMutableProjection(expr, schema)
-      }
-      val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
-      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
-        0, Seq.empty, newMutableProjection, Seq.empty, None,
-        dummyAccum, dummyAccum, dummyAccum, dummyAccum)
-      val numPages = iter.getHashMap.getNumDataPages
-      assert(numPages === 1)
-    } finally {
-      // Clean up
-      if (iter != null) {
-        iter.free()
-      }
-      TaskContext.unset()
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index ebe90d9..09847ce 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -23,6 +23,8 @@ import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.Map;
 
+import org.apache.spark.unsafe.Platform;
+
 /**
  * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
  */
@@ -45,9 +47,6 @@ public class HeapMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
-    if (size % 8 != 0) {
-      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
-    }
     if (shouldPool(size)) {
       synchronized (this) {
         final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
@@ -64,8 +63,8 @@ public class HeapMemoryAllocator implements MemoryAllocator {
         }
       }
     }
-    long[] array = new long[(int) (size / 8)];
-    return MemoryBlock.fromLongArray(array);
+    long[] array = new long[(int) ((size + 7) / 8)];
+    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
   }
 
   @Override

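With the multiple-of-8 check removed, the allocator itself now rounds the backing long[] up to whole 8-byte words, while the returned MemoryBlock keeps the exact requested size. A quick standalone check of that arithmetic:

```java
// Round-up arithmetic from the HeapMemoryAllocator change above.
public class RoundUpDemo {
  public static void main(String[] args) {
    long size = 13;                                   // no longer must be a multiple of 8
    long[] array = new long[(int) ((size + 7) / 8)];  // ceil(13 / 8) = 2 longs
    System.out.println(array.length * 8L);            // 16 bytes of backing storage
    // The MemoryBlock is constructed with the requested size (13), so callers
    // see exactly what they asked for; only the backing array is padded.
  }
}
```
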
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index cda7826..98ce711 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -26,9 +26,6 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
-    if (size % 8 != 0) {
-      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
-    }
     long address = Platform.allocateMemory(size);
     return new MemoryBlock(null, address, size);
   }



[3/3] spark git commit: [SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

Posted by da...@apache.org.
[SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

This PR introduces a mechanism to call spill() on those SQL operators that support spilling (for example, BytesToBytesMap, UnsafeExternalSorter and ShuffleExternalSorter) when there is not enough memory for execution. The preserved first page is no longer needed, so it has been removed.

Other Spillable objects in Spark core (ExternalSorter and AppendOnlyMap) are not included in this PR, but they could benefit from this mechanism (triggering others' spilling).

The PrepareRDD may not be needed anymore; it could be removed in a follow-up PR.

The following script fails with OOM before this PR; with it, the script finishes in 150 seconds with a 2 GB heap (it also works on the 1.5 branch, with a similar duration).

```python
sqlContext.setConf("spark.sql.shuffle.partitions", "1")
df = sqlContext.range(1<<25).selectExpr("id", "repeat(id, 2) as s")
df2 = df.select(df.id.alias('id2'), df.s.alias('s2'))
j = df.join(df2, df.id==df2.id2).groupBy(df.id).max("id", "id2")
j.explain()
print j.count()
```

For thread-safety, here is what I have:

1) Without calling spill(), the operators are only used by a single thread, so there are no safety problems.

2) spill() can be triggered in two ways: by the operator itself, or by other operators. We can check trigger == this in spill(); in that case it is still the same thread, so there are no safety problems (see the sketch after this list).

3) If it is triggered by other operators (right now, caching will not trigger spill()), we only spill the data to disk during the scanning stage (after building has finished), so the in-memory sorter and memory pages are read-only; we only need to synchronize the iterator when changing it.

4) During scanning, the iterator only uses one record in one page; we cannot free that page, because downstream code is still using it (via UnsafeRow or other objects). In BytesToBytesMap, we just skip the current page and dump all others to disk. In UnsafeExternalSorter, we keep the page used by the current record (the one with the same baseObject) and free it when loading the next record. In ShuffleExternalSorter, spill() is never triggered during scanning.

5) In order to avoid deadlock, we do not call acquireMemory() during spill() (which is why we reuse the pointer array in InMemorySorter).

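A minimal sketch of the trigger check from point 2 (simplified shape for illustration; the real signature is MemoryConsumer.spill(long, MemoryConsumer) in the diff below):

```java
import java.io.IOException;

// Illustrative only: a self-triggered spill runs on the caller's own thread and
// needs no locking; a spill triggered by another consumer happens during the
// read-only scanning stage, so only the iterator state needs synchronizing.
abstract class CooperativeConsumer {
  public long spill(long size, CooperativeConsumer trigger) throws IOException {
    if (trigger == this) {
      return spillEverything(size);        // same thread, no safety problems
    }
    synchronized (this) {                  // cross-thread: guard the iterator
      return spillAllButCurrentPage(size); // keep the page the reader is on
    }
  }
  protected abstract long spillEverything(long size) throws IOException;
  protected abstract long spillAllButCurrentPage(long size) throws IOException;
}
```
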
Author: Davies Liu <da...@databricks.com>

Closes #9241 from davies/force_spill.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56419cf1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56419cf1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56419cf1

Branch: refs/heads/master
Commit: 56419cf11f769c80f391b45dc41b3c7101cc5ff4
Parents: d89be0b
Author: Davies Liu <da...@databricks.com>
Authored: Thu Oct 29 23:38:06 2015 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Thu Oct 29 23:38:06 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/memory/MemoryConsumer.java | 128 ++++++
 .../apache/spark/memory/TaskMemoryManager.java  | 138 ++++--
 .../shuffle/sort/ShuffleExternalSorter.java     | 210 +++------
 .../shuffle/sort/ShuffleInMemorySorter.java     |  50 ++-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |   6 -
 .../spark/unsafe/map/BytesToBytesMap.java       | 430 +++++++++++--------
 .../unsafe/sort/UnsafeExternalSorter.java       | 426 +++++++++---------
 .../unsafe/sort/UnsafeInMemorySorter.java       |  60 ++-
 .../unsafe/sort/UnsafeSorterSpillReader.java    |   6 +-
 .../unsafe/sort/UnsafeSorterSpillWriter.java    |   2 +-
 .../org/apache/spark/memory/MemoryManager.scala |   9 +-
 .../spark/util/collection/Spillable.scala       |   4 +-
 .../spark/memory/TaskMemoryManagerSuite.java    |  77 +++-
 .../shuffle/sort/PackedRecordPointerSuite.java  |  30 +-
 .../sort/ShuffleInMemorySorterSuite.java        |   6 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  |  38 +-
 .../map/AbstractBytesToBytesMapSuite.java       | 149 ++++++-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  97 +++--
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  |  20 +-
 .../scala/org/apache/spark/FailureSuite.scala   |   4 +-
 .../memory/GrantEverythingMemoryManager.scala   |  54 ---
 .../spark/memory/MemoryManagerSuite.scala       |  60 +--
 .../apache/spark/memory/TestMemoryManager.scala |  70 +++
 .../sql/execution/UnsafeExternalRowSorter.java  |   7 +-
 .../UnsafeFixedWidthAggregationMap.java         |   2 +-
 .../sql/execution/UnsafeKVExternalSorter.java   |  19 +-
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  22 +-
 .../execution/UnsafeKVExternalSorterSuite.scala |   6 +-
 .../TungstenAggregationIteratorSuite.scala      |  54 ---
 .../unsafe/memory/HeapMemoryAllocator.java      |   9 +-
 .../unsafe/memory/UnsafeMemoryAllocator.java    |   3 -
 31 files changed, 1316 insertions(+), 880 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
new file mode 100644
index 0000000..008799c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+
+import java.io.IOException;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+
+/**
+ * A memory consumer of TaskMemoryManager that supports spilling.
+ */
+public abstract class MemoryConsumer {
+
+  private final TaskMemoryManager taskMemoryManager;
+  private final long pageSize;
+  private long used;
+
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
+    this.taskMemoryManager = taskMemoryManager;
+    this.pageSize = pageSize;
+    this.used = 0;
+  }
+
+  protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
+    this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
+  }
+
+  /**
+   * Returns the size of used memory in bytes.
+   */
+  long getUsed() {
+    return used;
+  }
+
+  /**
+   * Force spill during building.
+   *
+   * For testing.
+   */
+  public void spill() throws IOException {
+    spill(Long.MAX_VALUE, this);
+  }
+
+  /**
+   * Spill some data to disk to release memory. This will be called by TaskMemoryManager
+   * when there is not enough memory for the task.
+   *
+   * This should be implemented by subclasses.
+   *
+   * Note: in order to avoid possible deadlock, spill() should not call acquireMemory().
+   *
+   * @param size the amount of memory to be released
+   * @param trigger the MemoryConsumer that triggered this spilling
+   * @return the amount of released memory in bytes
+   * @throws IOException
+   */
+  public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
+
+  /**
+   * Acquire `size` bytes of memory.
+   *
+   * If there is not enough memory, throws OutOfMemoryError.
+   */
+  protected void acquireMemory(long size) {
+    long got = taskMemoryManager.acquireExecutionMemory(size, this);
+    if (got < size) {
+      taskMemoryManager.releaseExecutionMemory(got, this);
+      taskMemoryManager.showMemoryUsage();
+      throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
+    }
+    used += got;
+  }
+
+  /**
+   * Release `size` bytes of memory.
+   */
+  protected void releaseMemory(long size) {
+    used -= size;
+    taskMemoryManager.releaseExecutionMemory(size, this);
+  }
+
+  /**
+   * Allocate a memory block with at least `required` bytes.
+   *
+   * Throws OutOfMemoryError if there is not enough memory.
+   *
+   * @throws OutOfMemoryError
+   */
+  protected MemoryBlock allocatePage(long required) {
+    MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
+    if (page == null || page.size() < required) {
+      long got = 0;
+      if (page != null) {
+        got = page.size();
+        freePage(page);
+      }
+      taskMemoryManager.showMemoryUsage();
+      throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
+    }
+    used += page.size();
+    return page;
+  }
+
+  /**
+   * Free a memory block.
+   */
+  protected void freePage(MemoryBlock page) {
+    used -= page.size();
+    taskMemoryManager.freePage(page, this);
+  }
+}

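A hypothetical minimal subclass (not part of this patch) showing the contract the class defines: acquire pages through allocatePage(), and release whatever you can when spill() is invoked, whether by this consumer or by another one:

```java
import java.io.IOException;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;

// Illustrative consumer; the real implementations in this commit are
// BytesToBytesMap, UnsafeExternalSorter and ShuffleExternalSorter.
final class SketchConsumer extends MemoryConsumer {
  private MemoryBlock page = null;

  SketchConsumer(TaskMemoryManager taskMemoryManager) {
    super(taskMemoryManager);
  }

  void ensurePage(long required) {
    if (page == null) {
      // May trigger spilling of other consumers if the pool is exhausted.
      page = allocatePage(required);
    }
  }

  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    if (page == null) {
      return 0L;
    }
    // A real consumer would first persist the page's contents to disk.
    long released = page.size();
    freePage(page);
    page = null;
    return released;
  }
}
```
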
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 7b31c90..4230575 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -17,13 +17,18 @@
 
 package org.apache.spark.memory;
 
-import java.util.*;
+import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashSet;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
 
 /**
  * Manages the memory allocated by an individual task.
@@ -101,29 +106,104 @@ public class TaskMemoryManager {
   private final boolean inHeap;
 
   /**
+   * The set of consumers that have acquired memory in this task.
+   */
+  @GuardedBy("this")
+  private final HashSet<MemoryConsumer> consumers;
+
+  /**
    * Construct a new TaskMemoryManager.
    */
   public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
     this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
     this.memoryManager = memoryManager;
     this.taskAttemptId = taskAttemptId;
+    this.consumers = new HashSet<>();
   }
 
   /**
-   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+   * Acquire N bytes of memory for a consumer. If there is not enough memory, it will call
+   * spill() on consumers to release more memory.
+   *
    * @return number of bytes successfully granted (<= N).
    */
-  public long acquireExecutionMemory(long size) {
-    return memoryManager.acquireExecutionMemory(size, taskAttemptId);
+  public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
+    assert(required >= 0);
+    synchronized (this) {
+      long got = memoryManager.acquireExecutionMemory(required, taskAttemptId);
+
+      // Try to release memory from other consumers first; this reduces the frequency of
+      // spilling and avoids creating too many spill files.
+      if (got < required) {
+        // Call spill() on other consumers to release memory
+        for (MemoryConsumer c: consumers) {
+          if (c != null && c != consumer && c.getUsed() > 0) {
+            try {
+              long released = c.spill(required - got, consumer);
+              if (released > 0) {
+                logger.info("Task {} released {} from {} for {}", taskAttemptId,
+                  Utils.bytesToString(released), c, consumer);
+                got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+                if (got >= required) {
+                  break;
+                }
+              }
+            } catch (IOException e) {
+              logger.error("error while calling spill() on " + c, e);
+              throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+                + e.getMessage());
+            }
+          }
+        }
+      }
+
+      // call spill() on itself
+      if (got < required && consumer != null) {
+        try {
+          long released = consumer.spill(required - got, consumer);
+          if (released > 0) {
+            logger.info("Task {} released {} from itself ({})", taskAttemptId,
+              Utils.bytesToString(released), consumer);
+            got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+          }
+        } catch (IOException e) {
+          logger.error("error while calling spill() on " + consumer, e);
+          throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
+            + e.getMessage());
+        }
+      }
+
+      consumers.add(consumer);
+      logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
+      return got;
+    }
   }
 
   /**
-   * Release N bytes of execution memory.
+   * Release N bytes of execution memory for a MemoryConsumer.
    */
-  public void releaseExecutionMemory(long size) {
+  public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
+    logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
     memoryManager.releaseExecutionMemory(size, taskAttemptId);
   }
 
+  /**
+   * Dump the memory usage of all consumers.
+   */
+  public void showMemoryUsage() {
+    logger.info("Memory used in task " + taskAttemptId);
+    synchronized (this) {
+      for (MemoryConsumer c: consumers) {
+        if (c.getUsed() > 0) {
+          logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed()));
+        }
+      }
+    }
+  }
+
+  /**
+   * Return the page size in bytes.
+   */
   public long pageSizeBytes() {
     return memoryManager.pageSizeBytes();
   }
@@ -134,42 +214,40 @@ public class TaskMemoryManager {
    *
    * Returns `null` if there was not enough memory to allocate the page.
    */
-  public MemoryBlock allocatePage(long size) {
+  public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
     if (size > MAXIMUM_PAGE_SIZE_BYTES) {
       throw new IllegalArgumentException(
         "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
     }
 
+    long acquired = acquireExecutionMemory(size, consumer);
+    if (acquired <= 0) {
+      return null;
+    }
+
     final int pageNumber;
     synchronized (this) {
       pageNumber = allocatedPages.nextClearBit(0);
       if (pageNumber >= PAGE_TABLE_SIZE) {
+        releaseExecutionMemory(acquired, consumer);
         throw new IllegalStateException(
           "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
       }
       allocatedPages.set(pageNumber);
     }
-    final long acquiredExecutionMemory = acquireExecutionMemory(size);
-    if (acquiredExecutionMemory != size) {
-      releaseExecutionMemory(acquiredExecutionMemory);
-      synchronized (this) {
-        allocatedPages.clear(pageNumber);
-      }
-      return null;
-    }
-    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
+    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
     if (logger.isTraceEnabled()) {
-      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
     }
     return page;
   }
 
   /**
-   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
    */
-  public void freePage(MemoryBlock page) {
+  public void freePage(MemoryBlock page, MemoryConsumer consumer) {
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     assert(allocatedPages.get(page.pageNumber));
@@ -182,14 +260,14 @@ public class TaskMemoryManager {
     }
     long pageSize = page.size();
     memoryManager.tungstenMemoryAllocator().free(page);
-    releaseExecutionMemory(pageSize);
+    releaseExecutionMemory(pageSize, consumer);
   }
 
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
    *
-   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}.
    * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
    *                     this should be the value that you would pass as the base offset into an
    *                     UNSAFE call (e.g. page.baseOffset() + something).
@@ -261,17 +339,17 @@ public class TaskMemoryManager {
    * value can be used to detect memory leaks.
    */
   public long cleanUpAllAllocatedMemory() {
-    long freedBytes = 0;
-    for (MemoryBlock page : pageTable) {
-      if (page != null) {
-        freedBytes += page.size();
-        freePage(page);
+    synchronized (this) {
+      Arrays.fill(pageTable, null);
+      for (MemoryConsumer c: consumers) {
+        if (c != null && c.getUsed() > 0) {
+          // In the case of a failed task, it is normal to see leaked memory
+          logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
+        }
       }
+      consumers.clear();
     }
-
-    freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
-
-    return freedBytes;
+    return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
   }
 
   /**

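The grant order implemented above, condensed into a self-contained toy (single-threaded and unsynchronized, unlike the real code): take from the free pool, then spill other consumers, and only then spill the requester itself:

```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

class ToyTaskMemoryManager {
  private long free;
  private final List<ToyConsumer> consumers = new ArrayList<>();

  ToyTaskMemoryManager(long poolBytes) { this.free = poolBytes; }

  long acquire(long required, ToyConsumer requester) throws IOException {
    long got = grant(required);
    for (ToyConsumer c : consumers) {   // others first: fewer, larger spill files
      if (got >= required) break;
      if (c != requester && c.used() > 0) {
        free += c.spill(required - got, requester);
        got += grant(required - got);
      }
    }
    if (got < required) {               // finally, ask the requester to spill
      free += requester.spill(required - got, requester);
      got += grant(required - got);
    }
    if (!consumers.contains(requester)) consumers.add(requester);
    return got;                         // may still be < required; the caller decides
  }

  private long grant(long n) { long g = Math.min(n, free); free -= g; return g; }
}

abstract class ToyConsumer {
  abstract long used();
  abstract long spill(long size, ToyConsumer trigger) throws IOException;
}
```
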
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index f43236f..400d852 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -31,15 +31,15 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 /**
@@ -58,23 +58,18 @@ import org.apache.spark.util.Utils;
  * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
  * specialized merge procedure that avoids extra serialization/deserialization.
  */
-final class ShuffleExternalSorter {
+final class ShuffleExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
 
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
 
-  private final int initialSize;
   private final int numPartitions;
-  private final int pageSizeBytes;
-  @VisibleForTesting
-  final int maxRecordSizeBytes;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private final ShuffleWriteMetrics writeMetrics;
-  private long numRecordsInsertedSinceLastSpill = 0;
 
   /** Force this sorter to spill when there are this many elements in memory. For testing only */
   private final long numElementsForSpillThreshold;
@@ -98,8 +93,7 @@ final class ShuffleExternalSorter {
   // These variables are reset after spilling:
   @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
+  private long pageCursor = -1;
 
   public ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
@@ -108,42 +102,21 @@ final class ShuffleExternalSorter {
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) throws IOException {
+      ShuffleWriteMetrics writeMetrics) {
+    super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
+      memoryManager.pageSizeBytes()));
     this.taskMemoryManager = memoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
-    this.initialSize = initialSize;
-    this.peakMemoryUsedBytes = initialSize;
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.numElementsForSpillThreshold =
       conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
-    this.pageSizeBytes = (int) Math.min(
-      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
-    this.maxRecordSizeBytes = pageSizeBytes - 4;
     this.writeMetrics = writeMetrics;
-    initializeForWriting();
-
-    // preserve first page to ensure that we have at least one page to work with. Otherwise,
-    // other operators in the same task may starve this sorter (SPARK-9709).
-    acquireNewPageIfNecessary(pageSizeBytes);
-  }
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    // TODO: move this sizing calculation logic into a static method of sorter:
-    final long memoryRequested = initialSize * 8L;
-    final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
-    if (memoryAcquired != memoryRequested) {
-      taskMemoryManager.releaseExecutionMemory(memoryAcquired);
-      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
-    }
-
+    acquireMemory(initialSize * 8L);
     this.inMemSorter = new ShuffleInMemorySorter(initialSize);
-    numRecordsInsertedSinceLastSpill = 0;
+    this.peakMemoryUsedBytes = getMemoryUsage();
   }
 
   /**
@@ -242,6 +215,8 @@ final class ShuffleExternalSorter {
       }
     }
 
+    inMemSorter.reset();
+
     if (!isLastFile) {  // i.e. this is a spill file
       // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
       // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
@@ -266,9 +241,12 @@ final class ShuffleExternalSorter {
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  @VisibleForTesting
-  void spill() throws IOException {
-    assert(inMemSorter != null);
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
+      return 0L;
+    }
+
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -276,13 +254,9 @@ final class ShuffleExternalSorter {
       spills.size() > 1 ? " times" : " time");
 
     writeSortedFile(false);
-    final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
-    inMemSorter = null;
-    taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage);
     final long spillSize = freeMemory();
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
-
-    initializeForWriting();
+    return spillSize;
   }
 
   private long getMemoryUsage() {
@@ -312,18 +286,12 @@ final class ShuffleExternalSorter {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      taskMemoryManager.freePage(block);
       memoryFreed += block.size();
-    }
-    if (inMemSorter != null) {
-      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
-      inMemSorter = null;
-      taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
+      freePage(block);
     }
     allocatedPages.clear();
     currentPage = null;
-    currentPagePosition = -1;
-    freeSpaceInCurrentPage = 0;
+    pageCursor = 0;
     return memoryFreed;
   }
 
@@ -332,16 +300,16 @@ final class ShuffleExternalSorter {
    */
   public void cleanupResources() {
     freeMemory();
+    if (inMemSorter != null) {
+      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+      inMemSorter = null;
+      releaseMemory(sorterMemoryUsage);
+    }
     for (SpillInfo spill : spills) {
       if (spill.file.exists() && !spill.file.delete()) {
         logger.error("Unable to delete spill file {}", spill.file.getPath());
       }
     }
-    if (inMemSorter != null) {
-      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
-      inMemSorter = null;
-      taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
-    }
   }
 
   /**
@@ -352,16 +320,27 @@ final class ShuffleExternalSorter {
   private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
-      logger.debug("Attempting to expand sort pointer array");
-      final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
-      final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
-      final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray);
-      if (memoryAcquired < memoryToGrowPointerArray) {
-        taskMemoryManager.releaseExecutionMemory(memoryAcquired);
-        spill();
+      long used = inMemSorter.getMemoryUsage();
+      long needed = used + inMemSorter.getMemoryToExpand();
+      try {
+        acquireMemory(needed);  // could trigger spilling
+      } catch (OutOfMemoryError e) {
+        // spilling should already have been triggered
+        assert(inMemSorter.hasSpaceForAnotherRecord());
+        return;
+      }
+      // check whether spilling was triggered
+      if (inMemSorter.hasSpaceForAnotherRecord()) {
+        releaseMemory(needed);
       } else {
-        inMemSorter.expandPointerArray();
-        taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage);
+        try {
+          inMemSorter.expandPointerArray();
+          releaseMemory(used);
+        } catch (OutOfMemoryError oom) {
+          // Just in case the JVM has run out of memory
+          releaseMemory(needed);
+          spill();
+        }
       }
     }
   }
@@ -370,96 +349,46 @@ final class ShuffleExternalSorter {
    * Allocates more memory in order to insert an additional record. This will request additional
    * memory from the memory manager and spill if the requested memory can not be obtained.
    *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   * @param required the required space in the data page, in bytes, including space for storing
    *                      the record size. This must be less than or equal to the page size (records
    *                      that exceed the page size are handled via a different code path which uses
    *                      special overflow pages).
    */
-  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
-    growPointerArrayIfNecessary();
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        if (currentPage == null) {
-          spill();
-          currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-          if (currentPage == null) {
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
-      }
+  private void acquireNewPageIfNecessary(int required) {
+    if (currentPage == null ||
+      pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
+      // TODO: try to find space in previous pages
+      currentPage = allocatePage(required);
+      pageCursor = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
     }
   }
 
   /**
    * Write a record to the shuffle sorter.
    */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      int partitionId) throws IOException {
+  public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
+    throws IOException {
 
-    if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+    // for tests
+    assert(inMemSorter != null);
+    if (inMemSorter.numRecords() > numElementsForSpillThreshold) {
       spill();
     }
 
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
-    dataPagePosition += 4;
-    Platform.copyMemory(
-      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
-    assert(inMemSorter != null);
+    final int required = length + 4;
+    acquireNewPageIfNecessary(required);
+
+    assert(currentPage != null);
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, length);
+    pageCursor += 4;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
     inMemSorter.insertRecord(recordAddress, partitionId);
-    numRecordsInsertedSinceLastSpill += 1;
   }
 
   /**
@@ -475,6 +404,9 @@ final class ShuffleExternalSorter {
         // Do not count the final file towards the spill count.
         writeSortedFile(true);
         freeMemory();
+        long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+        inMemSorter = null;
+        releaseMemory(sorterMemoryUsage);
       }
       return spills.toArray(new SpillInfo[spills.size()]);
     } catch (IOException e) {

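The rewritten insertRecord() stores each record in the current page as a 4-byte length prefix followed by the record bytes, advancing pageCursor past both. A standalone illustration of that layout with a plain byte[] (the real code writes through Platform into pages obtained from TaskMemoryManager):

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class RecordLayoutDemo {
  public static void main(String[] args) {
    byte[] page = new byte[64];                // stands in for a data page
    ByteBuffer cursor = ByteBuffer.wrap(page); // stands in for pageCursor
    byte[] record = "hello".getBytes(StandardCharsets.UTF_8);
    cursor.putInt(record.length);              // Platform.putInt(base, pageCursor, length)
    cursor.put(record);                        // Platform.copyMemory(...)
    System.out.println(cursor.position());     // 9 = 4-byte prefix + 5 record bytes
  }
}
```
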
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index a8dee6c..e630575 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -37,33 +37,51 @@ final class ShuffleInMemorySorter {
    * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
    * records.
    */
-  private long[] pointerArray;
+  private long[] array;
 
   /**
    * The position in the pointer array where new records can be inserted.
    */
-  private int pointerArrayInsertPosition = 0;
+  private int pos = 0;
 
   public ShuffleInMemorySorter(int initialSize) {
     assert (initialSize > 0);
-    this.pointerArray = new long[initialSize];
-    this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
+    this.array = new long[initialSize];
+    this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
   }
 
-  public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
+  public int numRecords() {
+    return pos;
+  }
+
+  public void reset() {
+    pos = 0;
+  }
+
+  private int newLength() {
     // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+    return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+  }
+
+  /**
+   * Returns the memory needed, in bytes, to expand the pointer array.
+   */
+  public long getMemoryToExpand() {
+    return ((long) (newLength() - array.length)) * 8;
+  }
+
+  public void expandPointerArray() {
+    final long[] oldArray = array;
+    array = new long[newLength()];
+    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 1 < pointerArray.length;
+    return pos < array.length;
   }
 
   public long getMemoryUsage() {
-    return pointerArray.length * 8L;
+    return array.length * 8L;
   }
 
   /**
@@ -78,15 +96,15 @@ final class ShuffleInMemorySorter {
    */
   public void insertRecord(long recordPointer, int partitionId) {
     if (!hasSpaceForAnotherRecord()) {
-      if (pointerArray.length == Integer.MAX_VALUE) {
+      if (array.length == Integer.MAX_VALUE) {
         throw new IllegalStateException("Sort pointer array has reached maximum size");
       } else {
         expandPointerArray();
       }
     }
-    pointerArray[pointerArrayInsertPosition] =
+    array[pos] =
         PackedRecordPointer.packPointer(recordPointer, partitionId);
-    pointerArrayInsertPosition++;
+    pos++;
   }
 
   /**
@@ -118,7 +136,7 @@ final class ShuffleInMemorySorter {
    * Return an iterator over record pointers in sorted order.
    */
   public ShuffleSorterIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
-    return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+    sorter.sort(array, 0, pos, SORT_COMPARATOR);
+    return new ShuffleSorterIterator(pos, array);
   }
 }

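The overflow guard and getMemoryToExpand() above work together: the sorter asks the memory manager for exactly the growth delta before expanding the array. A quick standalone check of the arithmetic (illustrative):

```java
public class GrowthDemo {
  // newLength() from the diff above: double without wrapping past Integer.MAX_VALUE.
  static int newLength(int length) {
    return length <= Integer.MAX_VALUE / 2 ? (length * 2) : Integer.MAX_VALUE;
  }

  // getMemoryToExpand(): the growth delta in bytes (8 bytes per pointer slot).
  static long memoryToExpand(int length) {
    return ((long) (newLength(length) - length)) * 8;
  }

  public static void main(String[] args) {
    System.out.println(memoryToExpand(1 << 20));               // 8388608 (8 MB)
    System.out.println(newLength(Integer.MAX_VALUE / 2 + 1));  // clamps to 2147483647
  }
}
```
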
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f6c5c94..e19b378 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -127,12 +127,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     open();
   }
 
-  @VisibleForTesting
-  public int maxRecordSizeBytes() {
-    assert(sorter != null);
-    return sorter.maxRecordSizeBytes;
-  }
-
   private void updatePeakMemoryUsed() {
     // sorter can be null if this writer is closed
     if (sorter != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index f035bda..e36709c 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -18,14 +18,20 @@
 package org.apache.spark.unsafe.map;
 
 import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
 import java.util.Iterator;
 import java.util.LinkedList;
-import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
@@ -33,7 +39,8 @@ import org.apache.spark.unsafe.bitset.BitSet;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
 
 /**
  * An append-only hash map where keys and values are contiguous regions of bytes.
@@ -54,7 +61,7 @@ import org.apache.spark.memory.TaskMemoryManager;
  * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */
-public final class BytesToBytesMap {
+public final class BytesToBytesMap extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
 
@@ -62,27 +69,22 @@ public final class BytesToBytesMap {
 
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
-  /**
-   * Special record length that is placed after the last record in a data page.
-   */
-  private static final int END_OF_PAGE_MARKER = -1;
-
   private final TaskMemoryManager taskMemoryManager;
 
   /**
    * A linked list for tracking all allocated data pages so that we can free all of our memory.
    */
-  private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+  private final LinkedList<MemoryBlock> dataPages = new LinkedList<>();
 
   /**
    * The data page that will be used to store keys and values for new hashtable entries. When this
    * page becomes full, a new page will be allocated and this pointer will change to point to that
    * new page.
    */
-  private MemoryBlock currentDataPage = null;
+  private MemoryBlock currentPage = null;
 
   /**
-   * Offset into `currentDataPage` that points to the location where new data can be inserted into
+   * Offset into `currentPage` that points to the location where new data can be inserted into
    * the page. This does not incorporate the page's base offset.
    */
   private long pageCursor = 0;
@@ -117,6 +119,11 @@ public final class BytesToBytesMap {
   // absolute memory addresses.
 
   /**
+   * Whether or not the longArray can grow. We will not insert more elements if it's false.
+   */
+  private boolean canGrowArray = true;
+
+  /**
    * A {@link BitSet} used to track location of the map where the key is set.
    * Size of the bitset should be half of the size of the long array.
    */
@@ -164,13 +171,20 @@ public final class BytesToBytesMap {
 
   private long peakMemoryUsedBytes = 0L;
 
+  private final BlockManager blockManager;
+  private volatile MapIterator destructiveIterator = null;
+  private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
+      BlockManager blockManager,
       int initialCapacity,
       double loadFactor,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
+    super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
+    this.blockManager = blockManager;
     this.loadFactor = loadFactor;
     this.loc = new Location();
     this.pageSizeBytes = pageSizeBytes;
@@ -187,18 +201,13 @@ public final class BytesToBytesMap {
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
     allocate(initialCapacity);
-
-    // Acquire a new page as soon as we construct the map to ensure that we have at least
-    // one page to work with. Otherwise, other operators in the same task may starve this
-    // map (SPARK-9747).
-    acquireNewPage();
   }
 
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
       int initialCapacity,
       long pageSizeBytes) {
-    this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+    this(taskMemoryManager, initialCapacity, pageSizeBytes, false);
   }
 
   public BytesToBytesMap(
@@ -208,6 +217,7 @@ public final class BytesToBytesMap {
       boolean enablePerfMetrics) {
     this(
       taskMemoryManager,
+      SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
       initialCapacity,
       0.70,
       pageSizeBytes,
@@ -219,61 +229,153 @@ public final class BytesToBytesMap {
    */
   public int numElements() { return numElements; }
 
-  public static final class BytesToBytesMapIterator implements Iterator<Location> {
+  public final class MapIterator implements Iterator<Location> {
 
-    private final int numRecords;
-    private final Iterator<MemoryBlock> dataPagesIterator;
+    private int numRecords;
     private final Location loc;
 
     private MemoryBlock currentPage = null;
-    private int currentRecordNumber = 0;
+    private int recordsInPage = 0;
     private Object pageBaseObject;
     private long offsetInPage;
 
    // Whether this iterator is destructive or not. When it is true, it frees each page as it
    // moves onto the next one.
     private boolean destructive = false;
-    private BytesToBytesMap bmap;
+    private UnsafeSorterSpillReader reader = null;
 
-    private BytesToBytesMapIterator(
-        int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc,
-        boolean destructive, BytesToBytesMap bmap) {
+    private MapIterator(int numRecords, Location loc, boolean destructive) {
       this.numRecords = numRecords;
-      this.dataPagesIterator = dataPagesIterator;
       this.loc = loc;
       this.destructive = destructive;
-      this.bmap = bmap;
-      if (dataPagesIterator.hasNext()) {
-        advanceToNextPage();
+      if (destructive) {
+        destructiveIterator = this;
       }
     }
 
     private void advanceToNextPage() {
-      if (destructive && currentPage != null) {
-        dataPagesIterator.remove();
-        this.bmap.taskMemoryManager.freePage(currentPage);
+      synchronized (this) {
+        int nextIdx = dataPages.indexOf(currentPage) + 1;
+        if (destructive && currentPage != null) {
+          dataPages.remove(currentPage);
+          freePage(currentPage);
+          nextIdx--;
+        }
+        if (dataPages.size() > nextIdx) {
+          currentPage = dataPages.get(nextIdx);
+          pageBaseObject = currentPage.getBaseObject();
+          offsetInPage = currentPage.getBaseOffset();
+          recordsInPage = Platform.getInt(pageBaseObject, offsetInPage);
+          offsetInPage += 4;
+        } else {
+          currentPage = null;
+          if (reader != null) {
+            // remove the spill file from disk
+            File file = spillWriters.removeFirst().getFile();
+            if (file != null && file.exists()) {
+              if (!file.delete()) {
+                logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+              }
+            }
+          }
+          try {
+            reader = spillWriters.getFirst().getReader(blockManager);
+            recordsInPage = -1;
+          } catch (IOException e) {
+            // Scala iterators do not handle exceptions
+            Platform.throwException(e);
+          }
+        }
       }
-      currentPage = dataPagesIterator.next();
-      pageBaseObject = currentPage.getBaseObject();
-      offsetInPage = currentPage.getBaseOffset();
     }
 
     @Override
     public boolean hasNext() {
-      return currentRecordNumber != numRecords;
+      if (numRecords == 0) {
+        if (reader != null) {
+          // remove the spill file from disk
+          File file = spillWriters.removeFirst().getFile();
+          if (file != null && file.exists()) {
+            if (!file.delete()) {
+              logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+            }
+          }
+        }
+      }
+      return numRecords > 0;
     }
 
     @Override
     public Location next() {
-      int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
-      if (totalLength == END_OF_PAGE_MARKER) {
+      if (recordsInPage == 0) {
         advanceToNextPage();
-        totalLength = Platform.getInt(pageBaseObject, offsetInPage);
       }
-      loc.with(currentPage, offsetInPage);
-      offsetInPage += 4 + totalLength;
-      currentRecordNumber++;
-      return loc;
+      numRecords--;
+      if (currentPage != null) {
+        int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
+        loc.with(currentPage, offsetInPage);
+        offsetInPage += 4 + totalLength;
+        recordsInPage--;
+        return loc;
+      } else {
+        assert(reader != null);
+        if (!reader.hasNext()) {
+          advanceToNextPage();
+        }
+        try {
+          reader.loadNext();
+        } catch (IOException e) {
+          // Scala iterators do not handle exceptions
+          Platform.throwException(e);
+        }
+        loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength());
+        return loc;
+      }
+    }
+
+    public long spill(long numBytes) throws IOException {
+      synchronized (this) {
+        if (!destructive || dataPages.size() == 1) {
+          return 0L;
+        }
+
+        // TODO: use existing ShuffleWriteMetrics
+        ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+
+        long released = 0L;
+        while (dataPages.size() > 0) {
+          MemoryBlock block = dataPages.getLast();
+          // The currentPage is still in use and cannot be released
+          if (block == currentPage) {
+            break;
+          }
+
+          Object base = block.getBaseObject();
+          long offset = block.getBaseOffset();
+          int numRecords = Platform.getInt(base, offset);
+          offset += 4;
+          final UnsafeSorterSpillWriter writer =
+            new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
+          while (numRecords > 0) {
+            int length = Platform.getInt(base, offset);
+            writer.write(base, offset + 4, length, 0);
+            offset += 4 + length;
+            numRecords--;
+          }
+          writer.close();
+          spillWriters.add(writer);
+
+          dataPages.removeLast();
+          released += block.size();
+          freePage(block);
+
+          if (released >= numBytes) {
+            break;
+          }
+        }
+
+        return released;
+      }
     }
 
     @Override
@@ -290,8 +392,8 @@ public final class BytesToBytesMap {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public BytesToBytesMapIterator iterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this);
+  public MapIterator iterator() {
+    return new MapIterator(numElements, loc, false);
   }
 
   /**
@@ -304,8 +406,8 @@ public final class BytesToBytesMap {
    * If any other lookups or operations are performed on this map while iterating over it, including
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
-  public BytesToBytesMapIterator destructiveIterator() {
-    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this);
+  public MapIterator destructiveIterator() {
+    return new MapIterator(numElements, loc, true);
   }
 
   /**
@@ -314,11 +416,8 @@ public final class BytesToBytesMap {
    *
   * This function always returns the same {@link Location} instance to avoid object allocation.
    */
-  public Location lookup(
-      Object keyBaseObject,
-      long keyBaseOffset,
-      int keyRowLengthBytes) {
-    safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc);
+  public Location lookup(Object keyBase, long keyOffset, int keyLength) {
+    safeLookup(keyBase, keyOffset, keyLength, loc);
     return loc;
   }
 
@@ -327,18 +426,14 @@ public final class BytesToBytesMap {
    *
   * This is a thread-safe version of `lookup`; it can be used by multiple threads.
    */
-  public void safeLookup(
-      Object keyBaseObject,
-      long keyBaseOffset,
-      int keyRowLengthBytes,
-      Location loc) {
+  public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
     assert(bitset != null);
     assert(longArray != null);
 
     if (enablePerfMetrics) {
       numKeyLookups++;
     }
-    final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+    final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
     int pos = hashcode & mask;
     int step = 1;
     while (true) {
@@ -354,16 +449,16 @@ public final class BytesToBytesMap {
         if ((int) (stored) == hashcode) {
           // Full hash code matches.  Let's compare the keys for equality.
           loc.with(pos, hashcode, true);
-          if (loc.getKeyLength() == keyRowLengthBytes) {
+          if (loc.getKeyLength() == keyLength) {
             final MemoryLocation keyAddress = loc.getKeyAddress();
-            final Object storedKeyBaseObject = keyAddress.getBaseObject();
-            final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+            final Object storedkeyBase = keyAddress.getBaseObject();
+            final long storedkeyOffset = keyAddress.getBaseOffset();
             final boolean areEqual = ByteArrayMethods.arrayEquals(
-              keyBaseObject,
-              keyBaseOffset,
-              storedKeyBaseObject,
-              storedKeyBaseOffset,
-              keyRowLengthBytes
+              keyBase,
+              keyOffset,
+              storedkeyBase,
+              storedkeyOffset,
+              keyLength
             );
             if (areEqual) {
               return;
@@ -410,18 +505,18 @@ public final class BytesToBytesMap {
         taskMemoryManager.getOffsetInPage(fullKeyAddress));
     }
 
-    private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
-      long position = offsetInPage;
-      final int totalLength = Platform.getInt(page, position);
+    private void updateAddressesAndSizes(final Object base, final long offset) {
+      long position = offset;
+      final int totalLength = Platform.getInt(base, position);
       position += 4;
-      keyLength = Platform.getInt(page, position);
+      keyLength = Platform.getInt(base, position);
       position += 4;
       valueLength = totalLength - keyLength - 4;
 
-      keyMemoryLocation.setObjAndOffset(page, position);
+      keyMemoryLocation.setObjAndOffset(base, position);
 
       position += keyLength;
-      valueMemoryLocation.setObjAndOffset(page, position);
+      valueMemoryLocation.setObjAndOffset(base, position);
     }
 
     private Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -444,6 +539,19 @@ public final class BytesToBytesMap {
     }
 
     /**
+     * This is only used for spilling.
+     */
+    private Location with(Object base, long offset, int length) {
+      this.isDefined = true;
+      this.memoryPage = null;
+      keyLength = Platform.getInt(base, offset);
+      valueLength = length - 4 - keyLength;
+      keyMemoryLocation.setObjAndOffset(base, offset + 4);
+      valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
+      return this;
+    }
+
+    /**
      * Returns the memory page that contains the current record.
      * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
      */
@@ -517,9 +625,9 @@ public final class BytesToBytesMap {
      * As an example usage, here's the proper way to store a new key:
      * </p>
      * <pre>
-     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -531,113 +639,59 @@ public final class BytesToBytesMap {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(
-        Object keyBaseObject,
-        long keyBaseOffset,
-        int keyLengthBytes,
-        Object valueBaseObject,
-        long valueBaseOffset,
-        int valueLengthBytes) {
+    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
+        Object valueBase, long valueOffset, int valueLength) {
       assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLengthBytes % 8 == 0);
-      assert (valueLengthBytes % 8 == 0);
+      assert (keyLength % 8 == 0);
+      assert (valueLength % 8 == 0);
       assert(bitset != null);
       assert(longArray != null);
 
-      if (numElements == MAX_CAPACITY) {
-        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      if (numElements == MAX_CAPACITY || !canGrowArray) {
+        return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
      // (4 bytes total length) (4 bytes key length) (key) (value)
-      final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
-
-      // --- Figure out where to insert the new record ---------------------------------------------
-
-      final MemoryBlock dataPage;
-      final Object dataPageBaseObject;
-      final long dataPageInsertOffset;
-      boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
-      if (useOverflowPage) {
-        // The record is larger than the page size, so allocate a special overflow page just to hold
-        // that record.
-        final long overflowPageSize = requiredSize + 8;
-        MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
+      final long recordLength = 8 + keyLength + valueLength;
+      if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
+        if (!acquireNewPage(recordLength + 4L)) {
           return false;
         }
-        dataPages.add(overflowPage);
-        dataPage = overflowPage;
-        dataPageBaseObject = overflowPage.getBaseObject();
-        dataPageInsertOffset = overflowPage.getBaseOffset();
-      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
-        // The record can fit in a data page, but either we have not allocated any pages yet or
-        // the current page does not have enough space.
-        if (currentDataPage != null) {
-          // There wasn't enough space in the current page, so write an end-of-page marker:
-          final Object pageBaseObject = currentDataPage.getBaseObject();
-          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
-          Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
-        }
-        if (!acquireNewPage()) {
-          return false;
-        }
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset();
-      } else {
-        // There is enough space in the current data page.
-        dataPage = currentDataPage;
-        dataPageBaseObject = currentDataPage.getBaseObject();
-        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
       }
 
       // --- Append the key and value data to the current data page --------------------------------
-
-      long insertCursor = dataPageInsertOffset;
-
-      // Compute all of our offsets up-front:
-      final long recordOffset = insertCursor;
-      insertCursor += 4;
-      final long keyLengthOffset = insertCursor;
-      insertCursor += 4;
-      final long keyDataOffsetInPage = insertCursor;
-      insertCursor += keyLengthBytes;
-      final long valueDataOffsetInPage = insertCursor;
-      insertCursor += valueLengthBytes; // word used to store the value size
-
-      Platform.putInt(dataPageBaseObject, recordOffset,
-        keyLengthBytes + valueLengthBytes + 4);
-      Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
-      // Copy the key
-      Platform.copyMemory(
-        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
-      // Copy the value
-      Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
-        valueDataOffsetInPage, valueLengthBytes);
-
-      // --- Update bookeeping data structures -----------------------------------------------------
-
-      if (useOverflowPage) {
-        // Store the end-of-page marker at the end of the data page
-        Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
-      } else {
-        pageCursor += requiredSize;
-      }
-
+      final Object base = currentPage.getBaseObject();
+      long offset = currentPage.getBaseOffset() + pageCursor;
+      final long recordOffset = offset;
+      Platform.putInt(base, offset, keyLength + valueLength + 4);
+      Platform.putInt(base, offset + 4, keyLength);
+      offset += 8;
+      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
+      offset += keyLength;
+      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+
+      // --- Update bookkeeping data structures -----------------------------------------------------
+      offset = currentPage.getBaseOffset();
+      Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
+      pageCursor += recordLength;
       numElements++;
       bitset.set(pos);
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
-        dataPage, recordOffset);
+        currentPage, recordOffset);
       longArray.set(pos * 2, storedKeyAddress);
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
       isDefined = true;
+
       if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        growAndRehash();
+        try {
+          growAndRehash();
+        } catch (OutOfMemoryError oom) {
+          canGrowArray = false;
+        }
       }
       return true;
     }
@@ -647,18 +701,26 @@ public final class BytesToBytesMap {
    * Acquire a new page from the memory manager.
    * @return whether there is enough space to allocate the new page.
    */
-  private boolean acquireNewPage() {
-    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-    if (newPage == null) {
-      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+  private boolean acquireNewPage(long required) {
+    try {
+      currentPage = allocatePage(required);
+    } catch (OutOfMemoryError e) {
       return false;
     }
-    dataPages.add(newPage);
-    pageCursor = 0;
-    currentDataPage = newPage;
+    dataPages.add(currentPage);
+    Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
+    pageCursor = 4;
     return true;
   }
 
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this && destructiveIterator != null) {
+      return destructiveIterator.spill(size);
+    }
+    return 0L;
+  }
+
   /**
    * Allocate new data structures for this map. When calling this outside of the constructor,
    * make sure to keep references to the old data structures so that you can free them.
@@ -670,6 +732,7 @@ public final class BytesToBytesMap {
     // The capacity needs to be divisible by 64 so that our bit set can be sized properly
     capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
+    acquireMemory(capacity * 16);
     longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
     bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
 
@@ -678,22 +741,42 @@ public final class BytesToBytesMap {
   }
 
   /**
+   * Free the memory used by longArray.
+   */
+  public void freeArray() {
+    updatePeakMemoryUsed();
+    if (longArray != null) {
+      long used = longArray.memoryBlock().size();
+      longArray = null;
+      releaseMemory(used);
+      bitset = null;
+    }
+  }
+
+  /**
    * Free all allocated memory associated with this map, including the storage for keys and values
    * as well as the hash map array itself.
    *
    * This method is idempotent and can be called multiple times.
    */
   public void free() {
-    updatePeakMemoryUsed();
-    longArray = null;
-    bitset = null;
+    freeArray();
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
       MemoryBlock dataPage = dataPagesIterator.next();
       dataPagesIterator.remove();
-      taskMemoryManager.freePage(dataPage);
+      freePage(dataPage);
     }
     assert(dataPages.isEmpty());
+
+    while (!spillWriters.isEmpty()) {
+      File file = spillWriters.removeFirst().getFile();
+      if (file != null && file.exists()) {
+        if (!file.delete()) {
+          logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+        }
+      }
+    }
   }
 
   public TaskMemoryManager getTaskMemoryManager() {
@@ -782,7 +865,13 @@ public final class BytesToBytesMap {
     final int oldCapacity = (int) oldBitSet.capacity();
 
     // Allocate the new data structures
-    allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+    try {
+      allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+    } catch (OutOfMemoryError oom) {
+      longArray = oldLongArray;
+      bitset = oldBitSet;
+      throw oom;
+    }
 
     // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
     for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
@@ -806,6 +895,7 @@ public final class BytesToBytesMap {
         }
       }
     }
+    releaseMemory(oldLongArray.memoryBlock().size());
 
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;


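A note on the new insertion contract: putNewKey() no longer throws when memory
cannot be acquired; it returns false, and the caller is expected to fall back
to sort-based processing. A minimal caller-side sketch of that pattern (the
fallBackToSortBased() helper is hypothetical, not part of this patch):

    BytesToBytesMap.Location loc = map.lookup(keyBase, keyOffset, keyLength);
    if (!loc.isDefined()) {
      if (!loc.putNewKey(keyBase, keyOffset, keyLength,
                         valueBase, valueOffset, valueLength)) {
        // Memory could not be acquired, or the map can no longer grow:
        // spill the map's contents and continue with an external sorter.
        fallBackToSortBased();
      }
    }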


[2/3] spark git commit: [SPARK-10342] [SPARK-10309] [SPARK-10474] [SPARK-10929] [SQL] Cooperative memory management

Posted by da...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e317ea3..49a5a4b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,39 +17,34 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
 
-import javax.annotation.Nullable;
-
-import scala.runtime.AbstractFunction0;
-import scala.runtime.BoxedUnit;
-
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.Utils;
 
 /**
  * External sorter based on {@link UnsafeInMemorySorter}.
  */
-public final class UnsafeExternalSorter {
+public final class UnsafeExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
 
-  private final long pageSizeBytes;
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
-  private final int initialSize;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
@@ -69,14 +64,12 @@ public final class UnsafeExternalSorter {
   private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
 
   // These variables are reset after spilling:
-  @Nullable private UnsafeInMemorySorter inMemSorter;
-  // Whether the in-mem sorter is created internally, or passed in from outside.
-  // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
-  private boolean isInMemSorterExternal = false;
+  @Nullable private volatile UnsafeInMemorySorter inMemSorter;
+
   private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
+  private long pageCursor = -1;
   private long peakMemoryUsedBytes = 0;
+  private volatile SpillableIterator readingIterator = null;
 
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
@@ -86,7 +79,7 @@ public final class UnsafeExternalSorter {
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      UnsafeInMemorySorter inMemorySorter) throws IOException {
+      UnsafeInMemorySorter inMemorySorter) {
     return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
   }
@@ -98,7 +91,7 @@ public final class UnsafeExternalSorter {
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
-      long pageSizeBytes) throws IOException {
+      long pageSizeBytes) {
     return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
   }
@@ -111,60 +104,41 @@ public final class UnsafeExternalSorter {
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+      @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+    super(taskMemoryManager, pageSizeBytes);
     this.taskMemoryManager = taskMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.recordComparator = recordComparator;
     this.prefixComparator = prefixComparator;
-    this.initialSize = initialSize;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
     // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.fileBufferSizeBytes = 32 * 1024;
-    this.pageSizeBytes = pageSizeBytes;
+    // TODO: metrics tracking + integration with shuffle write metrics
+    // need to connect the write metrics to task metrics so we count the spill IO somewhere.
     this.writeMetrics = new ShuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
-      initializeForWriting();
-      // Acquire a new page as soon as we construct the sorter to ensure that we have at
-      // least one page to work with. Otherwise, other operators in the same task may starve
-      // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
-      acquireNewPage();
+      this.inMemSorter =
+        new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
+      acquireMemory(inMemSorter.getMemoryUsage());
     } else {
-      this.isInMemSorterExternal = true;
       this.inMemSorter = existingInMemorySorter;
+      // will acquire memory after freeing the map
     }
+    this.peakMemoryUsedBytes = getMemoryUsage();
 
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
    // the end of the task. This is necessary to avoid memory leaks when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
-    taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
-      @Override
-      public BoxedUnit apply() {
-        cleanupResources();
-        return null;
+    taskContext.addTaskCompletionListener(
+      new TaskCompletionListener() {
+        @Override
+        public void onTaskCompletion(TaskContext context) {
+          cleanupResources();
+        }
       }
-    });
-  }
-
-  // TODO: metrics tracking + integration with shuffle write metrics
-  // need to connect the write metrics to task metrics so we count the spill IO somewhere.
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    // Note: Do not track memory for the pointer array for now because of SPARK-10474.
-    // In more detail, in TungstenAggregate we only reserve a page, but when we fall back to
-    // sort-based aggregation we try to acquire a page AND a pointer array, which inevitably
-    // fails if all other memory is already occupied. It should be safe to not track the array
-    // because its memory footprint is frequently much smaller than that of a page. This is a
-    // temporary hack that we should address in 1.6.0.
-    // TODO: track the pointer array memory!
-    this.writeMetrics = new ShuffleWriteMetrics();
-    this.inMemSorter =
-      new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
-    this.isInMemSorterExternal = false;
+    );
   }
 
   /**
@@ -173,14 +147,27 @@ public final class UnsafeExternalSorter {
    */
   @VisibleForTesting
   public void closeCurrentPage() {
-    freeSpaceInCurrentPage = 0;
+    if (currentPage != null) {
+      pageCursor = currentPage.getBaseOffset() + currentPage.size();
+    }
   }
 
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  public void spill() throws IOException {
-    assert(inMemSorter != null);
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    if (trigger != this) {
+      if (readingIterator != null) {
+        return readingIterator.spill();
+      }
+      return 0L;
+    }
+
+    if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
+      return 0L;
+    }
+
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -202,6 +189,8 @@ public final class UnsafeExternalSorter {
         spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
       }
       spillWriter.close();
+
+      inMemSorter.reset();
     }
 
     final long spillSize = freeMemory();
@@ -210,7 +199,7 @@ public final class UnsafeExternalSorter {
     // written to disk. This also counts the space needed to store the sorter's pointer array.
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
-    initializeForWriting();
+    return spillSize;
   }
 
   /**
@@ -246,7 +235,7 @@ public final class UnsafeExternalSorter {
   }
 
   /**
-   * Free this sorter's in-memory data structures, including its data pages and pointer array.
+   * Free this sorter's data pages.
    *
    * @return the number of bytes freed.
    */
@@ -254,14 +243,12 @@ public final class UnsafeExternalSorter {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
-      taskMemoryManager.freePage(block);
       memoryFreed += block.size();
+      freePage(block);
     }
-    // TODO: track in-memory sorter memory usage (SPARK-10474)
     allocatedPages.clear();
     currentPage = null;
-    currentPagePosition = -1;
-    freeSpaceInCurrentPage = 0;
+    pageCursor = 0;
     return memoryFreed;
   }
 
@@ -283,8 +270,15 @@ public final class UnsafeExternalSorter {
    * Frees this sorter's in-memory data structures and cleans up its spill files.
    */
   public void cleanupResources() {
-    deleteSpillFiles();
-    freeMemory();
+    synchronized (this) {
+      deleteSpillFiles();
+      freeMemory();
+      if (inMemSorter != null) {
+        long used = inMemSorter.getMemoryUsage();
+        inMemSorter = null;
+        releaseMemory(used);
+      }
+    }
   }
 
   /**
@@ -295,8 +289,28 @@ public final class UnsafeExternalSorter {
   private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
-      // TODO: track the pointer array memory! (SPARK-10474)
-      inMemSorter.expandPointerArray();
+      long used = inMemSorter.getMemoryUsage();
+      long needed = used + inMemSorter.getMemoryToExpand();
+      try {
+        acquireMemory(needed);  // could trigger spilling
+      } catch (OutOfMemoryError e) {
+        // should have triggered spilling
+        assert(inMemSorter.hasSpaceForAnotherRecord());
+        return;
+      }
+      // check if spilling is triggered or not
+      if (inMemSorter.hasSpaceForAnotherRecord()) {
+        releaseMemory(needed);
+      } else {
+        try {
+          inMemSorter.expandPointerArray();
+          releaseMemory(used);
+        } catch (OutOfMemoryError oom) {
+          // Just in case the JVM has run out of memory
+          releaseMemory(needed);
+          spill();
+        }
+      }
     }
   }
 
@@ -304,101 +318,38 @@ public final class UnsafeExternalSorter {
    * Allocates more memory in order to insert an additional record. This will request additional
    * memory from the memory manager and spill if the requested memory can not be obtained.
    *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   * @param required the required space in the data page, in bytes, including space for storing
    *                      the record size. This must be less than or equal to the page size (records
    *                      that exceed the page size are handled via a different code path which uses
    *                      special overflow pages).
    */
-  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
-    assert (requiredSpace <= pageSizeBytes);
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        acquireNewPage();
-      }
+  private void acquireNewPageIfNecessary(int required) {
+    if (currentPage == null ||
+      pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
+      // TODO: try to find space on previous pages
+      currentPage = allocatePage(required);
+      pageCursor = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
     }
   }
 
   /**
-   * Acquire a new page from the memory manager.
-   *
-   * If there is not enough space to allocate the new page, spill all existing ones
-   * and try again. If there is still not enough space, report error to the caller.
-   */
-  private void acquireNewPage() throws IOException {
-    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-    if (currentPage == null) {
-      spill();
-      currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-      if (currentPage == null) {
-        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-      }
-    }
-    currentPagePosition = currentPage.getBaseOffset();
-    freeSpaceInCurrentPage = pageSizeBytes;
-    allocatedPages.add(currentPage);
-  }
-
-  /**
    * Write a record to the sorter.
    */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      long prefix) throws IOException {
+  public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+    throws IOException {
 
     growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    // --- Insert the record ----------------------------------------------------------------------
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
-    dataPagePosition += 4;
-    Platform.copyMemory(
-      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+    final int required = length + 4;
+    acquireNewPageIfNecessary(required);
+
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, length);
+    pageCursor += 4;
+    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+    pageCursor += length;
     assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
   }
@@ -411,59 +362,24 @@ public final class UnsafeExternalSorter {
    *
    * record length = key length + value length + 4
    */
-  public void insertKVRecord(
-      Object keyBaseObj, long keyOffset, int keyLen,
-      Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+  public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
+      Object valueBase, long valueOffset, int valueLen, long prefix)
+    throws IOException {
 
     growPointerArrayIfNecessary();
-    final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      if (overflowPage == null) {
-        spill();
-        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-        if (overflowPage == null) {
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    // --- Insert the record ----------------------------------------------------------------------
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
-    dataPagePosition += 4;
-
-    Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen);
-    dataPagePosition += 4;
-
-    Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
-    dataPagePosition += keyLen;
-
-    Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
+    final int required = keyLen + valueLen + 4 + 4;
+    acquireNewPageIfNecessary(required);
+
+    final Object base = currentPage.getBaseObject();
+    final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+    Platform.putInt(base, pageCursor, keyLen + valueLen + 4);
+    pageCursor += 4;
+    Platform.putInt(base, pageCursor, keyLen);
+    pageCursor += 4;
+    Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
+    pageCursor += keyLen;
+    Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
+    pageCursor += valueLen;
 
     assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
@@ -475,10 +391,10 @@ public final class UnsafeExternalSorter {
    */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
     assert(inMemSorter != null);
-    final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
-    int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+    readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+    int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
-      return inMemoryIterator;
+      return readingIterator;
     } else {
       final UnsafeSorterSpillMerger spillMerger =
         new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
@@ -486,9 +402,113 @@ public final class UnsafeExternalSorter {
         spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
       }
       spillWriters.clear();
-      spillMerger.addSpillIfNotEmpty(inMemoryIterator);
+      spillMerger.addSpillIfNotEmpty(readingIterator);
 
       return spillMerger.getSortedIterator();
     }
   }
+
+  /**
+   * An UnsafeSorterIterator that supports spilling.
+   */
+  class SpillableIterator extends UnsafeSorterIterator {
+    private UnsafeSorterIterator upstream;
+    private UnsafeSorterIterator nextUpstream = null;
+    private MemoryBlock lastPage = null;
+    private boolean loaded = false;
+    private int numRecords = 0;
+
+    public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+      this.upstream = inMemIterator;
+      this.numRecords = inMemIterator.numRecordsLeft();
+    }
+
+    public long spill() throws IOException {
+      synchronized (this) {
+        if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
+          && numRecords > 0)) {
+          return 0L;
+        }
+
+        UnsafeInMemorySorter.SortedIterator inMemIterator =
+          ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+
+        final UnsafeSorterSpillWriter spillWriter =
+          new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
+        while (inMemIterator.hasNext()) {
+          inMemIterator.loadNext();
+          final Object baseObject = inMemIterator.getBaseObject();
+          final long baseOffset = inMemIterator.getBaseOffset();
+          final int recordLength = inMemIterator.getRecordLength();
+          spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
+        }
+        spillWriter.close();
+        spillWriters.add(spillWriter);
+        nextUpstream = spillWriter.getReader(blockManager);
+
+        long released = 0L;
+        synchronized (UnsafeExternalSorter.this) {
+          // release the pages except the one that is used
+          for (MemoryBlock page : allocatedPages) {
+            if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
+              released += page.size();
+              freePage(page);
+            } else {
+              lastPage = page;
+            }
+          }
+          allocatedPages.clear();
+        }
+        return released;
+      }
+    }
+
+    @Override
+    public boolean hasNext() {
+      return numRecords > 0;
+    }
+
+    @Override
+    public void loadNext() throws IOException {
+      synchronized (this) {
+        loaded = true;
+        if (nextUpstream != null) {
+          // Just consumed the last record from the in-memory iterator
+          if (lastPage != null) {
+            freePage(lastPage);
+            lastPage = null;
+          }
+          upstream = nextUpstream;
+          nextUpstream = null;
+
+          assert(inMemSorter != null);
+          long used = inMemSorter.getMemoryUsage();
+          inMemSorter = null;
+          releaseMemory(used);
+        }
+        numRecords--;
+        upstream.loadNext();
+      }
+    }
+
+    @Override
+    public Object getBaseObject() {
+      return upstream.getBaseObject();
+    }
+
+    @Override
+    public long getBaseOffset() {
+      return upstream.getBaseOffset();
+    }
+
+    @Override
+    public int getRecordLength() {
+      return upstream.getRecordLength();
+    }
+
+    @Override
+    public long getKeyPrefix() {
+      return upstream.getKeyPrefix();
+    }
+  }
 }

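For context, getSortedIterator() now returns a SpillableIterator, which other
memory consumers can force to spill mid-scan; after a spill, loadNext()
transparently switches from the in-memory pages to the spill-file reader. A
rough consumption sketch (process() is a hypothetical downstream callback, and
IOException handling is elided):

    UnsafeSorterIterator it = sorter.getSortedIterator();
    while (it.hasNext()) {
      it.loadNext();  // may free in-memory pages and read from disk instead
      process(it.getBaseObject(), it.getBaseOffset(),
              it.getRecordLength(), it.getKeyPrefix());
    }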
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 5aad72c..1480f06 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -70,12 +70,12 @@ public final class UnsafeInMemorySorter {
   * Within this buffer, position {@code 2 * i} holds a pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
    */
-  private long[] pointerArray;
+  private long[] array;
 
   /**
    * The position in the sort buffer where new records can be inserted.
    */
-  private int pointerArrayInsertPosition = 0;
+  private int pos = 0;
 
   public UnsafeInMemorySorter(
       final TaskMemoryManager memoryManager,
@@ -83,37 +83,43 @@ public final class UnsafeInMemorySorter {
       final PrefixComparator prefixComparator,
       int initialSize) {
     assert (initialSize > 0);
-    this.pointerArray = new long[initialSize * 2];
+    this.array = new long[initialSize * 2];
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
   }
 
+  public void reset() {
+    pos = 0;
+  }
+
   /**
    * @return the number of records that have been inserted into this sorter.
    */
   public int numRecords() {
-    return pointerArrayInsertPosition / 2;
+    return pos / 2;
   }
 
-  public long getMemoryUsage() {
-    return pointerArray.length * 8L;
+  private int newLength() {
+    return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+  }
+
+  public long getMemoryToExpand() {
+    return (long) (newLength() - array.length) * 8L;
   }
 
-  static long getMemoryRequirementsForPointerArray(long numEntries) {
-    return numEntries * 2L * 8L;
+  public long getMemoryUsage() {
+    return array.length * 8L;
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 2 < pointerArray.length;
+    return pos + 2 <= array.length;
   }
 
   public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
-    // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+    final long[] oldArray = array;
+    array = new long[newLength()];
+    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
   }
 
   /**
@@ -127,10 +133,10 @@ public final class UnsafeInMemorySorter {
     if (!hasSpaceForAnotherRecord()) {
       expandPointerArray();
     }
-    pointerArray[pointerArrayInsertPosition] = recordPointer;
-    pointerArrayInsertPosition++;
-    pointerArray[pointerArrayInsertPosition] = keyPrefix;
-    pointerArrayInsertPosition++;
+    array[pos] = recordPointer;
+    pos++;
+    array[pos] = keyPrefix;
+    pos++;
   }
 
   public static final class SortedIterator extends UnsafeSorterIterator {
@@ -153,11 +159,25 @@ public final class UnsafeInMemorySorter {
       this.sortBuffer = sortBuffer;
     }
 
+    public SortedIterator clone() {
+      SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+      iter.position = position;
+      iter.baseObject = baseObject;
+      iter.baseOffset = baseOffset;
+      iter.keyPrefix = keyPrefix;
+      iter.recordLength = recordLength;
+      return iter;
+    }
+
     @Override
     public boolean hasNext() {
       return position < sortBufferInsertPosition;
     }
 
+    public int numRecordsLeft() {
+      return (sortBufferInsertPosition - position) / 2;
+    }
+
     @Override
     public void loadNext() {
       // This pointer points to a 4-byte record length, followed by the record's bytes
@@ -187,7 +207,7 @@ public final class UnsafeInMemorySorter {
    * {@code next()} will return the same mutable object.
    */
   public SortedIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
-    return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+    sorter.sort(array, 0, pos / 2, sortComparator);
+    return new SortedIterator(memoryManager, pos, array);
   }
 }

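To make the new growth accounting concrete: each record occupies two longs
(pointer + prefix), newLength() doubles the array (capped at
Integer.MAX_VALUE), and getMemoryToExpand() is the extra allocation that must
be acquired before expandPointerArray() copies the old entries across. A small
worked example with made-up numbers:

    int length = 8192;  // 4096 records * 2 longs => 64 KB currently held
    int newLength = length < Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE;
    long toExpand = (long) (newLength - length) * 8L;  // 65536 bytes to acquire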
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 501dfe7..039e940 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,18 +20,18 @@ package org.apache.spark.util.collection.unsafe.sort;
 import java.io.*;
 
 import com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 /**
  * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
  * of the file format).
  */
-final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
   private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
 
   private final File file;

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index e59a84f..234e211 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.Platform;
  *
  *   [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
  */
-final class UnsafeSorterSpillWriter {
+public final class UnsafeSorterSpillWriter {
 
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
 

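Since both UnsafeSorterSpillReader and the new MapIterator spill path consume
the layout described in the class comment above, here is a rough sketch of
walking that format with plain java.io (assumes java.io.* imports and a File
handle; buffering and error handling are elided, and this is not the reader's
actual implementation):

    try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
      int numRecords = in.readInt();
      for (int i = 0; i < numRecords; i++) {
        int length = in.readInt();    // record length in bytes
        long prefix = in.readLong();  // 8-byte key prefix used for sorting
        byte[] data = new byte[length];
        in.readFully(data);
        // hand (prefix, data) to the merger or iterator
      }
    }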
http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 6c9a71c..b0cf269 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import com.google.common.annotations.VisibleForTesting
 
+import org.apache.spark.util.Utils
 import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
 import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
 import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -215,8 +216,12 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte
   final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
     val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
     if (curMem < numBytes) {
-      throw new SparkException(
-        s"Internal error: release called on $numBytes bytes but task only has $curMem")
+      if (Utils.isTesting) {
+        throw new SparkException(
+          s"Internal error: release called on $numBytes bytes but task only has $curMem")
+      } else {
+        logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem")
+      }
     }
     if (executionMemoryForTask.contains(taskAttemptId)) {
       executionMemoryForTask(taskAttemptId) -= numBytes

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index a76891a..9e00262 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
+      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -107,7 +107,7 @@ private[spark] trait Spillable[C] extends Logging {
    */
   def releaseMemory(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
-    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
+    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null)
     myMemoryThreshold = initialMemoryThreshold
   }
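
Spillable itself is not a MemoryConsumer, so it passes null as the new second
argument: the memory manager then has no callback through which it could ask
this particular caller to spill. A minimal sketch of the two-argument API in
Java, assuming a TaskMemoryManager named tmm is in scope:

  // Callers that are not MemoryConsumers pass null; the manager may still ask
  // other consumers of the same task to spill to satisfy the request.
  long granted = tmm.acquireExecutionMemory(1024L, null);
  try {
    // use up to `granted` bytes
  } finally {
    tmm.releaseExecutionMemory(granted, null);  // pair every acquire with a release
  }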
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index f381db0..dab7b05 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.memory;
 
+import java.io.IOException;
+
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -25,19 +27,40 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
 
 public class TaskMemoryManagerSuite {
 
+  class TestMemoryConsumer extends MemoryConsumer {
+    TestMemoryConsumer(TaskMemoryManager memoryManager) {
+      super(memoryManager);
+    }
+
+    @Override
+    public long spill(long size, MemoryConsumer trigger) throws IOException {
+      long used = getUsed();
+      releaseMemory(used);
+      return used;
+    }
+
+    void use(long size) {
+      acquireMemory(size);
+    }
+
+    void free(long size) {
+      releaseMemory(size);
+    }
+  }
+
   @Test
   public void leakedPageMemoryIsDetected() {
     final TaskMemoryManager manager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    manager.allocatePage(4096);  // leak memory
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    manager.allocatePage(4096, null);  // leak memory
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
   @Test
   public void encodePageNumberAndOffsetOffHeap() {
     final TaskMemoryManager manager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256);
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256, null);
     // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
     // encode. This test exercises that corner-case:
     final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
@@ -49,11 +72,53 @@ public class TaskMemoryManagerSuite {
   @Test
   public void encodePageNumberAndOffsetOnHeap() {
     final TaskMemoryManager manager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    final MemoryBlock dataPage = manager.allocatePage(256);
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256, null);
     final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
     Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
     Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
   }
 
+  @Test
+  public void cooperativeSpilling() {
+    final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
+    memoryManager.limit(100);
+    final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
+
+    TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
+    TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
+    c1.use(100);
+    assert(c1.getUsed() == 100);
+    c2.use(100);
+    assert(c2.getUsed() == 100);
+    assert(c1.getUsed() == 0);  // spilled
+    c1.use(100);
+    assert(c1.getUsed() == 100);
+    assert(c2.getUsed() == 0);  // spilled
+
+    c1.use(50);
+    assert(c1.getUsed() == 50);  // spilled
+    assert(c2.getUsed() == 0);
+    c2.use(50);
+    assert(c1.getUsed() == 50);
+    assert(c2.getUsed() == 50);
+
+    c1.use(100);
+    assert(c1.getUsed() == 100);
+    assert(c2.getUsed() == 0);  // spilled
+
+    c1.free(20);
+    assert(c1.getUsed() == 80);
+    c2.use(10);
+    assert(c1.getUsed() == 80);
+    assert(c2.getUsed() == 10);
+    c2.use(100);
+    assert(c2.getUsed() == 100);
+    assert(c1.getUsed() == 0);  // spilled
+
+    c1.free(0);
+    c2.free(100);
+    assert(manager.cleanUpAllAllocatedMemory() == 0);
+  }
+
 }
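
The cooperativeSpilling test above pins down the new contract: once the pool is
capped, a request from one consumer forces the TaskMemoryManager to call
spill() on the others, and spill() must release memory and return the number of
bytes it actually freed. A hedged sketch of a slightly more realistic consumer
(the class and its fields are illustrative, not part of this patch):

  import java.io.IOException;
  import org.apache.spark.memory.MemoryConsumer;
  import org.apache.spark.memory.TaskMemoryManager;

  // Hypothetical consumer that buffers bytes and drops them when asked to spill.
  class BufferingConsumer extends MemoryConsumer {
    private long buffered = 0L;

    BufferingConsumer(TaskMemoryManager tmm) { super(tmm); }

    void add(long bytes) {
      acquireMemory(bytes);   // may trigger spill() on sibling consumers
      buffered += bytes;
    }

    @Override
    public long spill(long size, MemoryConsumer trigger) throws IOException {
      long freed = buffered;  // a real consumer would write the buffer out here
      releaseMemory(freed);   // give the bytes back to the pool
      buffered = 0L;
      return freed;           // report how much was actually released
    }
  }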

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 7fb2f92..9a43f1f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -17,25 +17,29 @@
 
 package org.apache.spark.shuffle.sort;
 
-import org.apache.spark.shuffle.sort.PackedRecordPointer;
+import java.io.IOException;
+
 import org.junit.Test;
-import static org.junit.Assert.*;
 
 import org.apache.spark.SparkConf;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 
 public class PackedRecordPointerSuite {
 
   @Test
-  public void heap() {
+  public void heap() throws IOException {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
+      new TaskMemoryManager(new TestMemoryManager(conf), 0);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -49,12 +53,12 @@ public class PackedRecordPointerSuite {
   }
 
   @Test
-  public void offHeap() {
+  public void offHeap() throws IOException {
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
+      new TaskMemoryManager(new TestMemoryManager(conf), 0);
+    final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+    final MemoryBlock page1 = memoryManager.allocatePage(128, null);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 5049a53..2293b1b 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -26,7 +26,7 @@ import org.junit.Test;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.memory.TaskMemoryManager;
 
@@ -60,8 +60,8 @@ public class ShuffleInMemorySorterSuite {
     };
     final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+      new TaskMemoryManager(new TestMemoryManager(conf), 0);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
     final HashPartitioner hashPartitioner = new HashPartitioner(4);

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index d659269..4763395 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -54,13 +54,14 @@ import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;
+  TestMemoryManager memoryManager;
   TaskMemoryManager taskMemoryManager;
   final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
   File mergedOutputFile;
@@ -106,10 +107,11 @@ public class UnsafeShuffleWriterSuite {
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
     conf = new SparkConf()
-      .set("spark.buffer.pageSize", "128m")
+      .set("spark.buffer.pageSize", "1m")
       .set("spark.unsafe.offHeap", "false");
     taskMetrics = new TaskMetrics();
-    taskMemoryManager =  new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
+    memoryManager = new TestMemoryManager(conf);
+    taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
 
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
@@ -344,9 +346,7 @@ public class UnsafeShuffleWriterSuite {
     }
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
 
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
+    assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
     ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
@@ -398,20 +398,14 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeEnoughDataToTriggerSpill() throws Exception {
-    taskMemoryManager = spy(taskMemoryManager);
-    doCallRealMethod() // initialize sort buffer
-      .doCallRealMethod() // allocate initial data page
-      .doReturn(0L) // deny request to allocate new page
-      .doCallRealMethod() // grant new sort buffer and data page
-      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+    memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
-    for (int i = 0; i < 128 + 1; i++) {
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10];
+    for (int i = 0; i < 10 + 1; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
     }
     writer.write(dataToWrite.iterator());
-    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -426,19 +420,13 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    taskMemoryManager = spy(taskMemoryManager);
-    doCallRealMethod() // initialize sort buffer
-      .doCallRealMethod() // allocate initial data page
-      .doReturn(0L) // deny request to allocate new page
-      .doCallRealMethod() // grant new sort buffer and data page
-      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+    memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     writer.write(dataToWrite.iterator());
-    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -473,11 +461,11 @@ public class UnsafeShuffleWriterSuite {
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
     dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
     // We should be able to write a record that's right _at_ the max record size
-    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
+    final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4];
     new Random(42).nextBytes(atMaxRecordSize);
     dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
     // Inserting a record that's larger than the max record size
-    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
+    final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()];
     new Random(42).nextBytes(exceedsMaxRecordSize);
     dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
     writer.write(dataToWrite.iterator());
@@ -524,7 +512,7 @@ public class UnsafeShuffleWriterSuite {
       for (int i = 0; i < numRecordsPerPage * 10; i++) {
         writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
         newPeakMemory = writer.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0 && i != 0) {
+        if (i % numRecordsPerPage == 0) {
           // The first page is allocated in constructor, another page will be allocated after
           // every numRecordsPerPage records (peak memory should change).
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
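
Replacing the Mockito spies with memoryManager.limit(...) makes these spill
tests deterministic: rather than scripting which allocation call is denied, the
test simply caps the pool and lets the writer run into the cap on its own. A
minimal sketch of the pattern, assuming the suite's memoryManager field from
setUp():

  memoryManager.limit(1024 * 1024);      // cap execution memory at 1 MB
  // ... write well over 1 MB of records; the writer must spill to make progress
  memoryManager.limit(Long.MAX_VALUE);   // effectively uncap for the remainder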

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 6e52496..92bd45e 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -17,40 +17,117 @@
 
 package org.apache.spark.unsafe.map;
 
-import java.lang.Exception;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import org.apache.spark.memory.TaskMemoryManager;
-import org.junit.*;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.junit.Assert.*;
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import org.apache.spark.SparkConf;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.util.Utils;
+
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.when;
 
 
 public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
-  private GrantEverythingMemoryManager memoryManager;
+  private TestMemoryManager memoryManager;
   private TaskMemoryManager taskMemoryManager;
   private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  File tempDir;
+
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+
+  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      return stream;
+    }
+  }
+
   @Before
   public void setup() {
     memoryManager =
-      new GrantEverythingMemoryManager(
+      new TestMemoryManager(
         new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
+    spillFilesCreated.clear();
+    MockitoAnnotations.initMocks(this);
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+      @Override
+      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+        File file = File.createTempFile("spillFile", ".spill", tempDir);
+        spillFilesCreated.add(file);
+        return Tuple2$.MODULE$.apply(blockId, file);
+      }
+    });
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+      .then(returnsSecondArg());
   }
 
   @After
   public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    tempDir = null;
+
     Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
     if (taskMemoryManager != null) {
       long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
@@ -415,9 +492,8 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void failureToAllocateFirstPage() {
-    memoryManager.markExecutionAsOutOfMemory();
+    memoryManager.limit(1024);  // longArray
     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
-    memoryManager.markExecutionAsOutOfMemory();
     try {
       final long[] emptyArray = new long[0];
       final BytesToBytesMap.Location loc =
@@ -439,7 +515,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       int i;
       for (i = 0; i < 127; i++) {
         if (i > 0) {
-          memoryManager.markExecutionAsOutOfMemory();
+          memoryManager.limit(0);
         }
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
@@ -457,6 +533,44 @@ public abstract class AbstractBytesToBytesMapSuite {
   }
 
   @Test
+  public void spillInIterator() throws IOException {
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
+    try {
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+        loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+      BytesToBytesMap.MapIterator iter = map.iterator();
+      for (i = 0; i < 100; i++) {
+        iter.next();
+      }
+      // Non-destructive iterator is not spillable
+      Assert.assertEquals(0, iter.spill(1024L * 10));
+      for (i = 100; i < 1024; i++) {
+        iter.next();
+      }
+
+      BytesToBytesMap.MapIterator iter2 = map.destructiveIterator();
+      for (i = 0; i < 100; i++) {
+        iter2.next();
+      }
+      Assert.assertTrue(iter2.spill(1024) >= 1024);
+      for (i = 100; i < 1024; i++) {
+        iter2.next();
+      }
+      assertFalse(iter2.hasNext());
+    } finally {
+      map.free();
+      for (File spillFile : spillFilesCreated) {
+        assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+          spillFile.exists());
+      }
+    }
+  }
+
+  @Test
   public void initialCapacityBoundsChecking() {
     try {
       new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
@@ -500,7 +614,7 @@ public abstract class AbstractBytesToBytesMapSuite {
           Platform.LONG_ARRAY_OFFSET,
           8);
         newPeakMemory = map.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0 && i > 0) {
+        if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
@@ -519,11 +633,4 @@ public abstract class AbstractBytesToBytesMapSuite {
     }
   }
 
-  @Test
-  public void testAcquirePageInConstructor() {
-    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
-    assertEquals(1, map.getNumDataPages());
-    map.free();
-  }
-
 }
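
The new spillInIterator test documents the two iterator flavors: map.iterator()
keeps the map's pages alive, so asking it to spill returns 0, while
map.destructiveIterator() releases pages as it advances and can therefore free
memory on demand without invalidating the remaining iteration. A hedged usage
sketch, assuming a populated BytesToBytesMap named map:

  // A non-destructive iterator from map.iterator() would return 0 from spill().
  BytesToBytesMap.MapIterator it = map.destructiveIterator();
  long freed = it.spill(1024L);  // >= 1024 when there are pages left to release
  while (it.hasNext()) {
    it.next();                   // iteration continues, now backed by spill files
  }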

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 94d50b9..cfead0e 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -36,28 +36,29 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-import static org.hamcrest.Matchers.greaterThanOrEqualTo;
-import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsSecondArg;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
 
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
 public class UnsafeExternalSorterSuite {
 
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
-  final GrantEverythingMemoryManager memoryManager =
-    new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+  final TestMemoryManager memoryManager =
+    new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = new PrefixComparator() {
@@ -86,7 +87,7 @@ public class UnsafeExternalSorterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
 
 
-  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
+  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
 
   private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
     @Override
@@ -233,7 +234,7 @@ public class UnsafeExternalSorterSuite {
       insertNumber(sorter, numRecords - i);
     }
     assertEquals(1, sorter.getNumberOfAllocatedPages());
-    memoryManager.markExecutionAsOutOfMemory();
+    memoryManager.markExecutionAsOutOfMemoryOnce();
     // The insertion of this record should trigger a spill:
     insertNumber(sorter, 0);
     // Ensure that spill files were created
@@ -312,6 +313,62 @@ public class UnsafeExternalSorterSuite {
   }
 
   @Test
+  public void forcedSpillingWithReadIterator() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    long[] record = new long[100];
+    int recordSize = record.length * 8;
+    int n = (int) pageSizeBytes / recordSize * 3;
+    for (int i = 0; i < n; i++) {
+      record[0] = (long) i;
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+    }
+    assert(sorter.getNumberOfAllocatedPages() >= 2);
+    UnsafeExternalSorter.SpillableIterator iter =
+      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+    int lastv = 0;
+    for (int i = 0; i < n / 3; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+      lastv = i;
+    }
+    assert(iter.spill() > 0);
+    assert(iter.spill() == 0);
+    assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv);
+    for (int i = n / 3; i < n; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void forcedSpillingWithNotReadIterator() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    long[] record = new long[100];
+    int recordSize = record.length * 8;
+    int n = (int) pageSizeBytes / recordSize * 3;
+    for (int i = 0; i < n; i++) {
+      record[0] = (long) i;
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+    }
+    assert(sorter.getNumberOfAllocatedPages() >= 2);
+    UnsafeExternalSorter.SpillableIterator iter =
+      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+    assert(iter.spill() > 0);
+    assert(iter.spill() == 0);
+    for (int i = 0; i < n; i++) {
+      iter.hasNext();
+      iter.loadNext();
+      assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
   public void testPeakMemoryUsed() throws Exception {
     final long recordLengthBytes = 8;
     final long pageSizeBytes = 256;
@@ -334,7 +391,7 @@ public class UnsafeExternalSorterSuite {
         insertNumber(sorter, i);
         newPeakMemory = sorter.getPeakMemoryUsedBytes();
         // The first page is pre-allocated on instantiation
-        if (i % numRecordsPerPage == 0 && i > 0) {
+        if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
@@ -358,21 +415,5 @@ public class UnsafeExternalSorterSuite {
     }
   }
 
-  @Test
-  public void testReservePageOnInstantiation() throws Exception {
-    final UnsafeExternalSorter sorter = newSorter();
-    try {
-      assertEquals(1, sorter.getNumberOfAllocatedPages());
-      // Inserting a new record doesn't allocate more memory since we already have a page
-      long peakMemory = sorter.getPeakMemoryUsedBytes();
-      insertNumber(sorter, 100);
-      assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
-      assertEquals(1, sorter.getNumberOfAllocatedPages());
-    } finally {
-      sorter.cleanupResources();
-      assertSpillFilesWereCleanedUp();
-    }
-  }
-
 }
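
The two forcedSpilling tests pin down SpillableIterator's behavior: the first
spill() writes the records still held in memory out to disk and returns a
positive byte count, a second call returns 0 because nothing is left to free,
and iteration afterwards resumes at the same position with the same values. A
short usage sketch, assuming a sorter holding several in-memory pages:

  UnsafeExternalSorter.SpillableIterator it =
    (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
  long freed = it.spill();       // pages released; records now stream from disk
  assert freed > 0;
  assert it.spill() == 0;        // nothing further to release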
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index d5de56a..642f658 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -20,17 +20,19 @@ package org.apache.spark.util.collection.unsafe.sort;
 import java.util.Arrays;
 
 import org.junit.Test;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.mock;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.isIn;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
 
 public class UnsafeInMemorySorterSuite {
 
@@ -44,7 +46,7 @@ public class UnsafeInMemorySorterSuite {
   public void testSortingEmptyInput() {
     final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
       new TaskMemoryManager(
-        new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
+        new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
       mock(RecordComparator.class),
       mock(PrefixComparator.class),
       100);
@@ -66,8 +68,8 @@ public class UnsafeInMemorySorterSuite {
       "Mango"
     };
     final TaskMemoryManager memoryManager = new TaskMemoryManager(
-      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:
     long position = dataPage.getBaseOffset();

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/scala/org/apache/spark/FailureSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 0242cbc..203dab9 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // cause is preserved
     val thrownDueToTaskFailure = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128, null)
         throw new Exception("intentional task failure")
         iter
       }.count()
@@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // If the task succeeded but memory was leaked, then the task should fail due to that leak
     val thrownDueToMemoryLeak = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocatePage(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128, null)
         iter
       }.count()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
deleted file mode 100644
index fe102d8..0000000
--- a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.memory
-
-import scala.collection.mutable
-
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.{BlockStatus, BlockId}
-
-class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
-  private[memory] override def doAcquireExecutionMemory(
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
-    if (oom) {
-      oom = false
-      0
-    } else {
-      _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
-      numBytes
-    }
-  }
-  override def acquireStorageMemory(
-      blockId: BlockId,
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
-  override def acquireUnrollMemory(
-      blockId: BlockId,
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
-  override def releaseStorageMemory(numBytes: Long): Unit = { }
-  override def maxExecutionMemory: Long = Long.MaxValue
-  override def maxStorageMemory: Long = Long.MaxValue
-
-  private var oom = false
-
-  def markExecutionAsOutOfMemory(): Unit = {
-    oom = true
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 1265087..4a9479c 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -145,20 +145,20 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val manager = createMemoryManager(1000L)
     val taskMemoryManager = new TaskMemoryManager(manager, 0)
 
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
-    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
-    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
-    assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
 
-    taskMemoryManager.releaseExecutionMemory(500L)
-    assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
-    assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
+    taskMemoryManager.releaseExecutionMemory(500L, null)
+    assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L)
 
     taskMemoryManager.cleanUpAllAllocatedMemory()
-    assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
-    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
   }
 
   test("two tasks requesting full execution memory") {
@@ -168,15 +168,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // Have both tasks request 500 bytes, then wait until both requests have been granted:
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 500L)
     assert(Await.result(t2Result1, futureTimeout) === 500L)
 
     // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
     // both now at 1 / N
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t1Result2, 200.millis) === 0L)
     assert(Await.result(t2Result2, 200.millis) === 0L)
   }
@@ -188,15 +188,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // Have both tasks request 250 bytes, then wait until both requests have been granted:
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 250L)
     assert(Await.result(t2Result1, futureTimeout) === 250L)
 
     // Have both tasks each request 500 bytes more.
     // We should only grant 250 bytes to each of them on this second request
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t1Result2, futureTimeout) === 250L)
     assert(Await.result(t2Result2, futureTimeout) === 250L)
   }
@@ -208,17 +208,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 1000L)
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
     // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
     // to make sure the other thread blocks for some time otherwise.
     Thread.sleep(300)
-    t1MemManager.releaseExecutionMemory(250L)
+    t1MemManager.releaseExecutionMemory(250L, null)
     // The memory freed from t1 should now be granted to t2.
     assert(Await.result(t2Result1, futureTimeout) === 250L)
     // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) }
     assert(Await.result(t2Result2, 200.millis) === 0L)
   }
 
@@ -229,18 +229,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val futureTimeout: Duration = 20.seconds
 
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 1000L)
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
     // to make sure the other thread blocks for some time otherwise.
     Thread.sleep(300)
     // t1 releases all of its memory, so t2 should be able to grab all of the memory
     t1MemManager.cleanUpAllAllocatedMemory()
     assert(Await.result(t2Result1, futureTimeout) === 500L)
-    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t2Result2, futureTimeout) === 500L)
-    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
     assert(Await.result(t2Result3, 200.millis) === 0L)
   }
 
@@ -251,13 +251,13 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     val t2MemManager = new TaskMemoryManager(memoryManager, 2)
     val futureTimeout: Duration = 20.seconds
 
-    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) }
     assert(Await.result(t1Result1, futureTimeout) === 700L)
 
-    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) }
     assert(Await.result(t2Result1, futureTimeout) === 300L)
 
-    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) }
     assert(Await.result(t1Result2, 200.millis) === 0L)
   }
 }
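
These updated calls leave the existing fair-sharing policy intact: with a
1000-byte pool, each of N concurrently active tasks can hold at most 1/N of the
pool and is guaranteed at least 1/(2N), which is why the second 500-byte
requests above return 0 once both tasks already sit at 500. A worked instance
of the arithmetic, sketched with the two task managers from the tests above:

  // pool = 1000 bytes, N = 2 active tasks
  // per-task cap        = pool / N    = 500 bytes
  // per-task guarantee  = pool / (2N) = 250 bytes
  long first  = t1MemManager.acquireExecutionMemory(500L, null);  // 500: at the cap
  long second = t1MemManager.acquireExecutionMemory(100L, null);  // 0: already at 1/N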

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
new file mode 100644
index 0000000..77e4355
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.{BlockStatus, BlockId}
+
+class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
+  private[memory] override def doAcquireExecutionMemory(
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
+    if (oomOnce) {
+      oomOnce = false
+      0
+    } else if (available >= numBytes) {
+      _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
+      available -= numBytes
+      numBytes
+    } else {
+      _executionMemoryUsed += available
+      val grant = available
+      available = 0
+      grant
+    }
+  }
+  override def acquireStorageMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def acquireUnrollMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def releaseExecutionMemory(numBytes: Long): Unit = {
+    available += numBytes
+    _executionMemoryUsed -= numBytes
+  }
+  override def releaseStorageMemory(numBytes: Long): Unit = {}
+  override def maxExecutionMemory: Long = Long.MaxValue
+  override def maxStorageMemory: Long = Long.MaxValue
+
+  private var oomOnce = false
+  private var available = Long.MaxValue
+
+  def markExecutionAsOutOfMemoryOnce(): Unit = {
+    oomOnce = true
+  }
+
+  def limit(avail: Long): Unit = {
+    available = avail
+  }
+
+}
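
TestMemoryManager replaces GrantEverythingMemoryManager with two knobs:
limit(n) caps the execution memory still available, so later requests are
granted only partially or not at all, and markExecutionAsOutOfMemoryOnce()
denies exactly the next request and then recovers. A minimal usage sketch from
a test's point of view (with no consumers registered, so no spilling kicks in):

  TestMemoryManager mm = new TestMemoryManager(new SparkConf());
  TaskMemoryManager tmm = new TaskMemoryManager(mm, 0);
  mm.markExecutionAsOutOfMemoryOnce();
  long g0 = tmm.acquireExecutionMemory(10L, null);  // 0: one simulated OOM
  long g1 = tmm.acquireExecutionMemory(10L, null);  // 10: back to normal
  mm.limit(100L);
  long g2 = tmm.acquireExecutionMemory(60L, null);  // 60: fits under the cap
  long g3 = tmm.acquireExecutionMemory(60L, null);  // 40: only what remains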

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 810c74f..f7063d1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -96,15 +96,10 @@ final class UnsafeExternalRowSorter {
     );
     numRowsInserted++;
     if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
-      spill();
+      sorter.spill();
     }
   }
 
-  @VisibleForTesting
-  void spill() throws IOException {
-    sorter.spill();
-  }
-
   /**
    * Return the peak memory used so far, in bytes.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 82c645d..889f970 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -165,7 +165,7 @@ public final class UnsafeFixedWidthAggregationMap {
   public KVIterator<UnsafeRow, UnsafeRow> iterator() {
     return new KVIterator<UnsafeRow, UnsafeRow>() {
 
-      private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator =
+      private final BytesToBytesMap.MapIterator mapLocationIterator =
         map.destructiveIterator();
       private final UnsafeRow key = new UnsafeRow();
       private final UnsafeRow value = new UnsafeRow();

http://git-wip-us.apache.org/repos/asf/spark/blob/56419cf1/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 46301f0..845f2ae 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.execution;
 
-import java.io.IOException;
-
 import javax.annotation.Nullable;
+import java.io.IOException;
 
 import com.google.common.annotations.VisibleForTesting;
 
 import org.apache.spark.TaskContext;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -33,7 +33,6 @@ import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.collection.unsafe.sort.*;
 
 /**
@@ -84,18 +83,16 @@ public final class UnsafeKVExternalSorter {
         /* initialSize */ 4096,
         pageSizeBytes);
     } else {
-      // Insert the records into the in-memory sorter.
-      // We will use the number of elements in the map as the initialSize of the
-      // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize,
-      // we will use 1 as its initial size if the map is empty.
-      // TODO: track pointer array memory used by this in-memory sorter! (SPARK-10474)
+      // The memory needed by UnsafeInMemorySorter should be less than that of the longArray in the map.
+      map.freeArray();
+      // The memory used by UnsafeInMemorySorter will be counted later (at the end of this block).
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
         taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
 
       // We cannot use the destructive iterator here because we are reusing the existing memory
       // pages in BytesToBytesMap to hold records during sorting.
       // The only new memory we are allocating is the pointer/prefix array.
-      BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+      BytesToBytesMap.MapIterator iter = map.iterator();
       final int numKeyFields = keySchema.size();
       UnsafeRow row = new UnsafeRow();
       while (iter.hasNext()) {
@@ -117,7 +114,7 @@ public final class UnsafeKVExternalSorter {
       }
 
       sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
-        taskContext.taskMemoryManager(),
+        taskMemoryManager,
         blockManager,
         taskContext,
         new KVComparator(ordering, keySchema.length()),
@@ -128,6 +125,8 @@ public final class UnsafeKVExternalSorter {
 
       sorter.spill();
       map.free();
+      // count the memory used by UnsafeInMemorySorter
+      taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
     }
   }
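
Two details in this else-branch are easy to miss: the map's longArray is freed
before the UnsafeInMemorySorter is constructed, so the sorter's pointer array
can reuse that space instead of pushing the task over its limit, and the
sorter's memory is charged to the task only at the end, after the spill has
freed the map's data pages. Passing `sorter` as the consumer in that final
acquireExecutionMemory call associates the reservation with a consumer the
manager can later ask to spill.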
 

