Posted to commits@spark.apache.org by jo...@apache.org on 2015/10/26 05:20:03 UTC

[1/3] spark git commit: [SPARK-10984] Simplify *MemoryManager class structure

Repository: spark
Updated Branches:
  refs/heads/master 63accc796 -> 85e654c5e


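The hunks below make two mechanical changes throughout the codebase: ExternalSorter gains a TaskContext as its first constructor argument, and the separate ShuffleMemoryManager parameter disappears from the unsafe sorters and maps, whose memory now flows through the task's TaskMemoryManager. A minimal before/after sketch, assuming only the constructor shapes visible in the hunks below:

    // Before: memory was coordinated through a separate ShuffleMemoryManager.
    val oldSorter = new ExternalSorter[Int, Int, Int](
      Some(agg), Some(new HashPartitioner(3)), Some(ord), None)

    // After: the TaskContext (and its TaskMemoryManager) is passed explicitly.
    val newSorter = new ExternalSorter[Int, Int, Int](
      TaskContext.get(), Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
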
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index e2cb791..d7b2d07 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util.collection
 
+import org.apache.spark.memory.MemoryTestingUtils
+
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
@@ -98,6 +100,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val conf = createSparkConf(loadDefaults = true, kryo = false)
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
 
     def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
     def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
@@ -109,7 +112,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       createCombiner _, mergeValue _, mergeCombiners _)
 
     val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
-      Some(agg), None, None, None)
+      context, Some(agg), None, None, None)
 
     val collisionPairs = Seq(
       ("Aa", "BB"),                   // 2112
@@ -158,8 +161,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val conf = createSparkConf(loadDefaults = true, kryo = false)
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
     val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
-    val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
+    val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None)
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).
     val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1)
@@ -180,6 +184,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val conf = createSparkConf(loadDefaults = true, kryo = false)
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
 
     def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i)
     def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i
@@ -188,7 +193,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     }
 
     val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
-    val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
+    val sorter =
+      new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None)
     sorter.insertAll(
       (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
     assert(sorter.numSpills > 0, "sorter did not spill")
@@ -204,6 +210,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val conf = createSparkConf(loadDefaults = true, kryo = false)
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
 
     def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
     def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
@@ -214,7 +221,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       createCombiner, mergeValue, mergeCombiners)
 
     val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
-      Some(agg), None, None, None)
+      context, Some(agg), None, None, None)
 
     sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
       (null.asInstanceOf[String], "1"),
@@ -271,31 +278,32 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
   private def emptyDataStream(conf: SparkConf) {
     conf.set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val ord = implicitly[Ordering[Int]]
 
     // Both aggregator and ordering
     val sorter = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+      context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
     assert(sorter.iterator.toSeq === Seq())
     sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(3)), None, None)
+      context, Some(agg), Some(new HashPartitioner(3)), None, None)
     assert(sorter2.iterator.toSeq === Seq())
     sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
+      context, None, Some(new HashPartitioner(3)), Some(ord), None)
     assert(sorter3.iterator.toSeq === Seq())
     sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), None, None)
+      context, None, Some(new HashPartitioner(3)), None, None)
     assert(sorter4.iterator.toSeq === Seq())
     sorter4.stop()
   }
@@ -303,6 +311,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
   private def fewElementsPerPartition(conf: SparkConf) {
     conf.set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
     val ord = implicitly[Ordering[Int]]
@@ -313,28 +322,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
 
     // Both aggregator and ordering
     val sorter = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+      context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
     sorter.insertAll(elements.iterator)
     assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(7)), None, None)
+      context, Some(agg), Some(new HashPartitioner(7)), None, None)
     sorter2.insertAll(elements.iterator)
     assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), Some(ord), None)
+      context, None, Some(new HashPartitioner(7)), Some(ord), None)
     sorter3.insertAll(elements.iterator)
     assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), None, None)
+      context, None, Some(new HashPartitioner(7)), None, None)
     sorter4.insertAll(elements.iterator)
     assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
     sorter4.stop()
@@ -345,12 +354,13 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     conf.set("spark.shuffle.manager", "sort")
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local", "test", conf)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
 
     val ord = implicitly[Ordering[Int]]
     val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2))
 
     val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), Some(ord), None)
+      context, None, Some(new HashPartitioner(7)), Some(ord), None)
     sorter.insertAll(elements)
     assert(sorter.numSpills > 0, "sorter did not spill")
     val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -432,8 +442,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val diskBlockManager = sc.env.blockManager.diskBlockManager
     val ord = implicitly[Ordering[Int]]
     val expectedSize = if (withFailures) size - 1 else size
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
     val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
+      context, None, Some(new HashPartitioner(3)), Some(ord), None)
     if (withFailures) {
       intercept[SparkException] {
         sorter.insertAll((0 until size).iterator.map { i =>
@@ -501,7 +512,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
         None
       }
     val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None
-    val sorter = new ExternalSorter[Int, Int, Int](agg, Some(new HashPartitioner(3)), ord, None)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+    val sorter =
+      new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None)
     sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) })
     if (withSpilling) {
       assert(sorter.numSpills > 0, "sorter did not spill")
@@ -538,8 +551,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
 
     val testData = Array.tabulate(size) { _ => rand.nextInt().toString }
 
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
     val sorter1 = new ExternalSorter[String, String, String](
-      None, None, Some(wrongOrdering), None)
+      context, None, None, Some(wrongOrdering), None)
     val thrown = intercept[IllegalArgumentException] {
       sorter1.insertAll(testData.iterator.map(i => (i, i)))
       assert(sorter1.numSpills > 0, "sorter did not spill")
@@ -561,7 +575,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       createCombiner, mergeValue, mergeCombiners)
 
     val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]](
-      Some(agg), None, None, None)
+      context, Some(agg), None, None, None)
     sorter2.insertAll(testData.iterator.map(i => (i, i)))
     assert(sorter2.numSpills > 0, "sorter did not spill")
 

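All of the ExternalSorterSuite changes above follow one pattern: build a fake TaskContext from the SparkEnv and pass it as the sorter's new first argument. A condensed sketch of that pattern, with the aggregator and input data standing in for the suite's own values:

    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
    val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
      context, Some(agg), None, None, None)
    sorter.insertAll(testData.iterator.map(i => (i, i)))
    assert(sorter.numSpills > 0, "sorter did not spill")
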
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 7d94e05..810c74f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -67,7 +67,6 @@ final class UnsafeExternalRowSorter {
     final TaskContext taskContext = TaskContext.get();
     sorter = UnsafeExternalSorter.create(
       taskContext.taskMemoryManager(),
-      sparkEnv.shuffleMemoryManager(),
       sparkEnv.blockManager(),
       taskContext,
       new RowComparator(ordering, schema.length()),

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 09511ff..82c645d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -22,7 +22,6 @@ import java.io.IOException;
 import com.google.common.annotations.VisibleForTesting;
 
 import org.apache.spark.SparkEnv;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -32,7 +31,7 @@ import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 /**
  * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -88,8 +87,6 @@ public final class UnsafeFixedWidthAggregationMap {
    * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
    * @param groupingKeySchema the schema of the grouping key, used for row conversion.
    * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
-   * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
-   *                             other tasks.
    * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
    * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
    * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -99,15 +96,14 @@ public final class UnsafeFixedWidthAggregationMap {
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+    this.map =
+      new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
     this.enablePerfMetrics = enablePerfMetrics;
 
     // Initialize the buffer for aggregation value
@@ -256,7 +252,7 @@ public final class UnsafeFixedWidthAggregationMap {
   public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
     UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
       groupingKeySchema, aggregationBufferSchema,
-      SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map);
+      SparkEnv.get().blockManager(), map.getPageSizeBytes(), map);
     return sorter;
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 9df5780..46301f0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -24,7 +24,6 @@ import javax.annotation.Nullable;
 import com.google.common.annotations.VisibleForTesting;
 
 import org.apache.spark.TaskContext;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -34,7 +33,7 @@ import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.collection.unsafe.sort.*;
 
 /**
@@ -50,14 +49,19 @@ public final class UnsafeKVExternalSorter {
   private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
   private final UnsafeExternalSorter sorter;
 
-  public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
-      BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes)
-    throws IOException {
-    this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null);
+  public UnsafeKVExternalSorter(
+      StructType keySchema,
+      StructType valueSchema,
+      BlockManager blockManager,
+      long pageSizeBytes) throws IOException {
+    this(keySchema, valueSchema, blockManager, pageSizeBytes, null);
   }
 
-  public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
-      BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes,
+  public UnsafeKVExternalSorter(
+      StructType keySchema,
+      StructType valueSchema,
+      BlockManager blockManager,
+      long pageSizeBytes,
       @Nullable BytesToBytesMap map) throws IOException {
     this.keySchema = keySchema;
     this.valueSchema = valueSchema;
@@ -73,7 +77,6 @@ public final class UnsafeKVExternalSorter {
     if (map == null) {
       sorter = UnsafeExternalSorter.create(
         taskMemoryManager,
-        shuffleMemoryManager,
         blockManager,
         taskContext,
         recordComparator,
@@ -115,7 +118,6 @@ public final class UnsafeKVExternalSorter {
 
       sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
         taskContext.taskMemoryManager(),
-        shuffleMemoryManager,
         blockManager,
         taskContext,
         new KVComparator(ordering, keySchema.length()),

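With the ShuffleMemoryManager argument removed, UnsafeKVExternalSorter is now constructed from the schemas, the block manager, and a page size alone; the task's memory manager is obtained internally via TaskContext.get(). A sketch of the new call site, mirroring the updated test suite later in this commit:

    val sorter = new UnsafeKVExternalSorter(
      keySchema, valueSchema, SparkEnv.get.blockManager, pageSize)
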
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 7cd0f7b..fb2fc98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, Logging, TaskContext}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
  *
  * This iterator first uses hash-based aggregation to process input rows. It uses
 * a hash map to store groups and their corresponding aggregation buffers. If
- * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * this map cannot allocate memory from the memory manager,
 * it switches to sort-based aggregation. The process of the switch has the following steps:
  *  - Step 1: Sort all entries of the hash map based on values of grouping expressions and
  *            spill them to disk.
@@ -480,10 +480,9 @@ class TungstenAggregationIterator(
     initialAggregationBuffer,
     StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
-    TaskContext.get.taskMemoryManager(),
-    SparkEnv.get.shuffleMemoryManager,
+    TaskContext.get().taskMemoryManager(),
     1024 * 16, // initial capacity
-    SparkEnv.get.shuffleMemoryManager.pageSizeBytes,
+    TaskContext.get().taskMemoryManager().pageSizeBytes,
     false // disable tracking of performance metrics
   )
 

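The page size is now read from the task's own memory manager rather than from a SparkEnv-wide shuffle memory manager, a substitution that recurs in sort.scala and WriterContainer.scala below:

    // Old lookup: SparkEnv.get.shuffleMemoryManager.pageSizeBytes
    val pageSizeBytes = TaskContext.get().taskMemoryManager().pageSizeBytes
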
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index cfd64c1..1b59b19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -344,8 +344,7 @@ private[sql] class DynamicPartitionWriterContainer(
               StructType.fromAttributes(partitionColumns),
               StructType.fromAttributes(dataColumns),
               SparkEnv.get.blockManager,
-              SparkEnv.get.shuffleMemoryManager,
-              SparkEnv.get.shuffleMemoryManager.pageSizeBytes)
+              TaskContext.get().taskMemoryManager().pageSizeBytes)
             sorter.insertKV(currentKey, getOutputRow(inputRow))
           }
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index bc255b2..cc8abb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -21,7 +21,7 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
 import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.local.LocalNode
 import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.memory.MemoryLocation
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.{SparkConf, SparkEnv}
@@ -320,21 +320,20 @@ private[joins] final class UnsafeHashedRelation(
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     val nKeys = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
-    val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    // TODO(josh): This needs to be revisited before we merge this patch; making this change now
+    // so that tests compile:
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0)
 
-    val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes)
+    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
       .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
 
-    // Dummy shuffle memory manager which always grants all memory allocation requests.
-    // We use this because it doesn't make sense count shared broadcast variables' memory usage
-    // towards individual tasks' quotas. In the future, we should devise a better way of handling
-    // this.
-    val shuffleMemoryManager =
-      ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes)
+    // TODO(josh): We won't need this dummy memory manager after future refactorings; revisit
+    // during code review
 
     binaryMap = new BytesToBytesMap(
       taskMemoryManager,
-      shuffleMemoryManager,
       (nKeys * 1.5 + 1).toInt, // reduce hash collision
       pageSizeBytes)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 9385e57..dd92dda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -49,7 +49,8 @@ case class Sort(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
     child.execute().mapPartitions( { iterator =>
       val ordering = newOrdering(sortOrder, child.output)
-      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
+      val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
+        TaskContext.get(), ordering = Some(ordering))
       sorter.insertAll(iterator.map(r => (r.copy(), null)))
       val baseIterator = sorter.iterator.map(_._1)
       val context = TaskContext.get()
@@ -124,7 +125,7 @@ case class TungstenSort(
         }
       }
 
-      val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes
+      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
       val sorter = new UnsafeExternalRowSorter(
         schema, ordering, prefixComparator, prefixComputer, pageSize)
       if (testSpillFrequency > 0) {

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
deleted file mode 100644
index c4358f4..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import scala.collection.mutable
-
-import org.apache.spark.memory.MemoryManager
-import org.apache.spark.shuffle.ShuffleMemoryManager
-import org.apache.spark.storage.{BlockId, BlockStatus}
-
-
-/**
- * A [[ShuffleMemoryManager]] that can be controlled to run out of memory.
- */
-class TestShuffleMemoryManager
-  extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) {
-  private var oom = false
-
-  override def tryToAcquire(numBytes: Long): Long = {
-    if (oom) {
-      oom = false
-      0
-    } else {
-      // Uncomment the following to trace memory allocations.
-      // println(s"tryToAcquire $numBytes in " +
-      //   Thread.currentThread().getStackTrace.mkString("", "\n  -", ""))
-      val acquired = super.tryToAcquire(numBytes)
-      acquired
-    }
-  }
-
-  override def release(numBytes: Long): Unit = {
-    // Uncomment the following to trace memory releases.
-    // println(s"release $numBytes in " +
-    //   Thread.currentThread().getStackTrace.mkString("", "\n  -", ""))
-    super.release(numBytes)
-  }
-
-  def markAsOutOfMemory(): Unit = {
-    oom = true
-  }
-}
-
-private class GrantEverythingMemoryManager extends MemoryManager {
-  override def acquireExecutionMemory(
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = numBytes
-  override def acquireStorageMemory(
-      blockId: BlockId,
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
-  override def acquireUnrollMemory(
-      blockId: BlockId,
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
-  override def releaseExecutionMemory(numBytes: Long): Unit = { }
-  override def releaseStorageMemory(numBytes: Long): Unit = { }
-  override def maxExecutionMemory: Long = Long.MaxValue
-  override def maxStorageMemory: Long = Long.MaxValue
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 1739798..dbf4863 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,13 +23,12 @@ import scala.util.{Try, Random}
 
 import org.scalatest.Matchers
 
-import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite}
+import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -49,23 +48,22 @@ class UnsafeFixedWidthAggregationMapSuite
   private def emptyAggregationBuffer: InternalRow = InternalRow(0)
   private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
 
+  private var memoryManager: GrantEverythingMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
-  private var shuffleMemoryManager: TestShuffleMemoryManager = null
 
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
     def cleanup(): Unit = {
       if (taskMemoryManager != null) {
-        val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
         assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
-        assert(leakedShuffleMemory === 0)
         taskMemoryManager = null
       }
       TaskContext.unset()
     }
 
     test(name) {
-      taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-      shuffleMemoryManager = new TestShuffleMemoryManager
+      val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
+      memoryManager = new GrantEverythingMemoryManager(conf)
+      taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
 
       TaskContext.setTaskContext(new TaskContextImpl(
         stageId = 0,
@@ -110,7 +108,6 @@ class UnsafeFixedWidthAggregationMapSuite
       aggBufferSchema,
       groupKeySchema,
       taskMemoryManager,
-      shuffleMemoryManager,
       1024, // initial capacity,
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -125,7 +122,6 @@ class UnsafeFixedWidthAggregationMapSuite
       aggBufferSchema,
       groupKeySchema,
       taskMemoryManager,
-      shuffleMemoryManager,
       1024, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -153,7 +149,6 @@ class UnsafeFixedWidthAggregationMapSuite
       aggBufferSchema,
       groupKeySchema,
       taskMemoryManager,
-      shuffleMemoryManager,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -176,14 +171,13 @@ class UnsafeFixedWidthAggregationMapSuite
 
   testWithMemoryLeakDetection("test external sorting") {
     // Memory consumption in the beginning of the task.
-    val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+    val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
 
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
       taskMemoryManager,
-      shuffleMemoryManager,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -200,7 +194,7 @@ class UnsafeFixedWidthAggregationMapSuite
     val sorter = map.destructAndCreateExternalSorter()
 
     withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
-      assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
+      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
     }
 
     // Add more keys to the sorter and make sure the results come out sorted.
@@ -214,7 +208,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
 
       if ((i % 100) == 0) {
-        shuffleMemoryManager.markAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemory()
         sorter.closeCurrentPage()
       }
     }
@@ -238,7 +232,6 @@ class UnsafeFixedWidthAggregationMapSuite
       aggBufferSchema,
       groupKeySchema,
       taskMemoryManager,
-      shuffleMemoryManager,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -258,7 +251,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
 
       if ((i % 100) == 0) {
-        shuffleMemoryManager.markAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemory()
         sorter.closeCurrentPage()
       }
     }
@@ -281,14 +274,13 @@ class UnsafeFixedWidthAggregationMapSuite
   testWithMemoryLeakDetection("test external sorting with empty records") {
 
     // Memory consumption in the beginning of the task.
-    val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+    val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
 
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       StructType(Nil),
       StructType(Nil),
       taskMemoryManager,
-      shuffleMemoryManager,
       128, // initial capacity
       PAGE_SIZE_BYTES,
       false // disable perf metrics
@@ -303,7 +295,7 @@ class UnsafeFixedWidthAggregationMapSuite
     val sorter = map.destructAndCreateExternalSorter()
 
     withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
-      assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
+      assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
     }
 
     // Add more keys to the sorter and make sure the results come out sorted.
@@ -311,7 +303,7 @@ class UnsafeFixedWidthAggregationMapSuite
       sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
 
       if ((i % 100) == 0) {
-        shuffleMemoryManager.markAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemory()
         sorter.closeCurrentPage()
       }
     }
@@ -332,34 +324,28 @@ class UnsafeFixedWidthAggregationMapSuite
   }
 
   testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") {
-    val smm = ShuffleMemoryManager.createForTesting(65536)
     val pageSize = 4096
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
       taskMemoryManager,
-      smm,
       128, // initial capacity
       pageSize,
       false // disable perf metrics
     )
 
-    // Insert into the map until we've run out of space
     val rand = new Random(42)
-    var hasSpace = true
-    while (hasSpace) {
+    for (i <- 1 to 100) {
       val str = rand.nextString(1024)
       val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
-      if (buf == null) {
-        hasSpace = false
-      } else {
-        buf.setInt(0, str.length)
-      }
+      buf.setInt(0, str.length)
     }
-
-    // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte
-    assert(smm.tryToAcquire(1) === 0)
+    // Simulate running out of space
+    memoryManager.markExecutionAsOutOfMemory()
+    val str = rand.nextString(1024)
+    val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
+    assert(buf == null)
 
     // Convert the map into a sorter. This used to fail before the fix for SPARK-10474
     // because we would try to acquire space for the in-memory sorter pointer array before

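The deleted TestShuffleMemoryManager's markAsOutOfMemory() is replaced by GrantEverythingMemoryManager.markExecutionAsOutOfMemory(), which makes the next allocation fail and so forces a spill. A condensed sketch of the setup these tests now share (the trailing 0 passed to TaskMemoryManager is presumably a task attempt id, as used throughout these suites):

    val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
    val memoryManager = new GrantEverythingMemoryManager(conf)
    val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
    // ... later, to simulate memory pressure, deny the next request:
    memoryManager.markExecutionAsOutOfMemory()
    sorter.closeCurrentPage()
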
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index d3be568..13dc175 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.execution
 import scala.util.Random
 
 import org.apache.spark._
+import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 
 /**
  * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
@@ -108,9 +108,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       inputData: Seq[(InternalRow, InternalRow)],
       pageSize: Long,
       spill: Boolean): Unit = {
-
-    val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-    val shuffleMemMgr = new TestShuffleMemoryManager
+    val memoryManager =
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
+    val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
       partitionId = 0,
@@ -121,14 +121,14 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       internalAccumulators = Seq.empty))
 
     val sorter = new UnsafeKVExternalSorter(
-      keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize)
+      keySchema, valueSchema, SparkEnv.get.blockManager, pageSize)
 
     // Insert the keys and values into the sorter
     inputData.foreach { case (k, v) =>
       sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
       // 1% chance we will spill
       if (rand.nextDouble() < 0.01 && spill) {
-        shuffleMemMgr.markAsOutOfMemory()
+        memoryManager.markExecutionAsOutOfMemory()
         sorter.closeCurrentPage()
       }
     }
@@ -170,12 +170,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering))
 
     // Make sure there is no memory leak
-    val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
-    if (shuffleMemMgr != null) {
-      val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
-      assert(0L === leakedShuffleMemory)
-    }
-    assert(0 === leakedUnsafeMemory)
+    assert(0 === taskMemMgr.cleanUpAllAllocatedMemory)
     TaskContext.unset()
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1680d7e..d32572b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
 
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
@@ -112,7 +113,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
       val data = (1 to 10000).iterator.map { i =>
         (i, converter(Row(i)))
       }
+      val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
+      val taskContext = new TaskContextImpl(
+        0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc))
+
       val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        taskContext,
         partitioner = Some(new HashPartitioner(10)),
         serializer = Some(new UnsafeRowSerializer(numFields = 1)))
 
@@ -122,10 +128,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
       assert(sorter.numSpills > 0)
 
       // Merging spilled files should not throw assertion error
-      val taskContext =
-        new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
       taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
     } {
       // Clean up
       if (sc != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index cc0ac1b..475037b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -18,16 +18,16 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark._
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.memory.TaskMemoryManager
 
 class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
 
   test("memory acquired on construction") {
-    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
     val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
     TaskContext.setTaskContext(taskContext)
 

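The same substitution appears here in miniature: where a TaskMemoryManager used to wrap SparkEnv's ExecutorMemoryManager, it now wraps the unified MemoryManager directly:

    // Old: new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
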
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index b2b6848..c17fb72 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -254,7 +254,7 @@ class ReceivedBlockHandlerSuite
       maxMem: Long,
       conf: SparkConf,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
     val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
deleted file mode 100644
index cbbe859..0000000
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.unsafe.memory;
-
-import java.lang.ref.WeakReference;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.Map;
-import javax.annotation.concurrent.GuardedBy;
-
-/**
- * Manages memory for an executor. Individual operators / tasks allocate memory through
- * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager.
- */
-public class ExecutorMemoryManager {
-
-  /**
-   * Allocator, exposed for enabling untracked allocations of temporary data structures.
-   */
-  public final MemoryAllocator allocator;
-
-  /**
-   * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe.
-   */
-  final boolean inHeap;
-
-  @GuardedBy("this")
-  private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
-    new HashMap<Long, LinkedList<WeakReference<MemoryBlock>>>();
-
-  private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
-
-  /**
-   * Construct a new ExecutorMemoryManager.
-   *
-   * @param allocator the allocator that will be used
-   */
-  public ExecutorMemoryManager(MemoryAllocator allocator) {
-    this.inHeap = allocator instanceof HeapMemoryAllocator;
-    this.allocator = allocator;
-  }
-
-  /**
-   * Returns true if allocations of the given size should go through the pooling mechanism and
-   * false otherwise.
-   */
-  private boolean shouldPool(long size) {
-    // Very small allocations are less likely to benefit from pooling.
-    // At some point, we should explore supporting pooling for off-heap memory, but for now we'll
-    // ignore that case in the interest of simplicity.
-    return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator;
-  }
-
-  /**
-   * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
-   * to be zeroed out (call `zero()` on the result if this is necessary).
-   */
-  MemoryBlock allocate(long size) throws OutOfMemoryError {
-    if (shouldPool(size)) {
-      synchronized (this) {
-        final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
-        if (pool != null) {
-          while (!pool.isEmpty()) {
-            final WeakReference<MemoryBlock> blockReference = pool.pop();
-            final MemoryBlock memory = blockReference.get();
-            if (memory != null) {
-              assert (memory.size() == size);
-              return memory;
-            }
-          }
-          bufferPoolsBySize.remove(size);
-        }
-      }
-      return allocator.allocate(size);
-    } else {
-      return allocator.allocate(size);
-    }
-  }
-
-  void free(MemoryBlock memory) {
-    final long size = memory.size();
-    if (shouldPool(size)) {
-      synchronized (this) {
-        LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
-        if (pool == null) {
-          pool = new LinkedList<WeakReference<MemoryBlock>>();
-          bufferPoolsBySize.put(size, pool);
-        }
-        pool.add(new WeakReference<MemoryBlock>(memory));
-      }
-    } else {
-      allocator.free(memory);
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index 6722301..ebe90d9 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -17,22 +17,71 @@
 
 package org.apache.spark.unsafe.memory;
 
+import javax.annotation.concurrent.GuardedBy;
+import java.lang.ref.WeakReference;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.Map;
+
 /**
  * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
  */
 public class HeapMemoryAllocator implements MemoryAllocator {
 
+  @GuardedBy("this")
+  private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
+    new HashMap<>();
+
+  private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
+
+  /**
+   * Returns true if allocations of the given size should go through the pooling mechanism and
+   * false otherwise.
+   */
+  private boolean shouldPool(long size) {
+    // Very small allocations are less likely to benefit from pooling.
+    return size >= POOLING_THRESHOLD_BYTES;
+  }
+
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
     if (size % 8 != 0) {
       throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
     }
+    if (shouldPool(size)) {
+      synchronized (this) {
+        final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+        if (pool != null) {
+          while (!pool.isEmpty()) {
+            final WeakReference<MemoryBlock> blockReference = pool.pop();
+            final MemoryBlock memory = blockReference.get();
+            if (memory != null) {
+              assert (memory.size() == size);
+              return memory;
+            }
+          }
+          bufferPoolsBySize.remove(size);
+        }
+      }
+    }
     long[] array = new long[(int) (size / 8)];
     return MemoryBlock.fromLongArray(array);
   }
 
   @Override
   public void free(MemoryBlock memory) {
-    // Do nothing
+    final long size = memory.size();
+    if (shouldPool(size)) {
+      synchronized (this) {
+        LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+        if (pool == null) {
+          pool = new LinkedList<>();
+          bufferPoolsBySize.put(size, pool);
+        }
+        pool.add(new WeakReference<>(memory));
+      }
+    } else {
+      // Do nothing
+    }
   }
 }

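The buffer-pooling logic moves from the deleted ExecutorMemoryManager into HeapMemoryAllocator itself: freed on-heap blocks of at least POOLING_THRESHOLD_BYTES (1 MB) are parked behind weak references keyed by size, and handed back on the next same-size allocation. An illustrative sketch of the intended reuse (not a guarantee: the GC may clear the weak reference, in which case a fresh block is allocated):

    val allocator = new HeapMemoryAllocator
    val block = allocator.allocate(1024 * 1024)  // large enough to be poolable
    allocator.free(block)                        // parked in the pool, not discarded
    val reused = allocator.allocate(1024 * 1024) // may return the pooled block
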
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index dd75820..e3e7947 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -30,9 +30,10 @@ public class MemoryBlock extends MemoryLocation {
 
   /**
    * Optional page number; used when this MemoryBlock represents a page allocated by a
-   * MemoryManager. This is package-private and is modified by MemoryManager.
+   * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
+   * which lives in a different package.
    */
-  int pageNumber = -1;
+  public int pageNumber = -1;
 
   public MemoryBlock(@Nullable Object obj, long offset, long length) {
     super(obj, offset);

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
deleted file mode 100644
index 97b2c93..0000000
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ /dev/null
@@ -1,286 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.unsafe.memory;
-
-import java.util.*;
-
-import com.google.common.annotations.VisibleForTesting;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Manages the memory allocated by an individual task.
- * <p>
- * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs.
- * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is
- * addressed by the combination of a base Object reference and a 64-bit offset within that object.
- * This is a problem when we want to store pointers to data structures inside of other structures,
- * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits
- * to address memory, we can't just store the address of the base object since it's not guaranteed
- * to remain stable as the heap gets reorganized due to GC.
- * <p>
- * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap
- * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to
- * store a "page number" and the lower 51 bits to store an offset within this page. These page
- * numbers are used to index into a "page table" array inside of the MemoryManager in order to
- * retrieve the base object.
- * <p>
- * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the
- * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is
- * approximately 140 terabytes of memory.
- */
-public class TaskMemoryManager {
-
-  private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
-
-  /** The number of bits used to address the page table. */
-  private static final int PAGE_NUMBER_BITS = 13;
-
-  /** The number of bits used to encode offsets in data pages. */
-  @VisibleForTesting
-  static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;  // 51
-
-  /** The number of entries in the page table. */
-  private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
-
-  /**
-   * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
-   * (1L &lt;&lt; OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
-   * size is limited by the maximum amount of data that can be stored in a long[] array, which is
-   * (2^31 - 1) * 8 bytes (or about 16 gigabytes). Therefore, we cap this at 16 gigabytes.
-   */
-  public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
-
-  /** Bit mask for the lower 51 bits of a long. */
-  private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
-
-  /** Bit mask for the upper 13 bits of a long */
-  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
-
-  /**
-   * Similar to an operating system's page table, this array maps page numbers into base object
-   * pointers, allowing us to translate between the hashtable's internal 64-bit address
-   * representation and the baseObject+offset representation which we use to support both in- and
-   * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`.
-   * When using an in-heap allocator, the entries in this map will point to pages' base objects.
-   * Entries are added to this map as new data pages are allocated.
-   */
-  private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
-
-  /**
-   * Bitmap for tracking free pages.
-   */
-  private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
-
-  /**
-   * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean
-   * up leaked memory.
-   */
-  private final HashSet<MemoryBlock> allocatedNonPageMemory = new HashSet<MemoryBlock>();
-
-  private final ExecutorMemoryManager executorMemoryManager;
-
-  /**
-   * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
-   * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
-   * this extra layer of indirection / abstraction hopefully shouldn't be too expensive.
-   */
-  private final boolean inHeap;
-
-  /**
-   * Construct a new TaskMemoryManager.
-   */
-  public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
-    this.inHeap = executorMemoryManager.inHeap;
-    this.executorMemoryManager = executorMemoryManager;
-  }
-
-  /**
-   * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
-   * intended for allocating large blocks of memory that will be shared between operators.
-   */
-  public MemoryBlock allocatePage(long size) {
-    if (size > MAXIMUM_PAGE_SIZE_BYTES) {
-      throw new IllegalArgumentException(
-        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
-    }
-
-    final int pageNumber;
-    synchronized (this) {
-      pageNumber = allocatedPages.nextClearBit(0);
-      if (pageNumber >= PAGE_TABLE_SIZE) {
-        throw new IllegalStateException(
-          "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
-      }
-      allocatedPages.set(pageNumber);
-    }
-    final MemoryBlock page = executorMemoryManager.allocate(size);
-    page.pageNumber = pageNumber;
-    pageTable[pageNumber] = page;
-    if (logger.isTraceEnabled()) {
-      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
-    }
-    return page;
-  }
-
-  /**
-   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
-   */
-  public void freePage(MemoryBlock page) {
-    assert (page.pageNumber != -1) :
-      "Called freePage() on memory that wasn't allocated with allocatePage()";
-    assert(allocatedPages.get(page.pageNumber));
-    pageTable[page.pageNumber] = null;
-    synchronized (this) {
-      allocatedPages.clear(page.pageNumber);
-    }
-    if (logger.isTraceEnabled()) {
-      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
-    }
-    // Cannot access a page once it's freed.
-    executorMemoryManager.free(page);
-  }
-
-  /**
-   * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
-   * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended
-   * to be used for allocating operators' internal data structures. For data pages that you want to
-   * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since
-   * that will enable intra-memory pointers (see
-   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's
-   * top-level Javadoc for more details).
-   */
-  public MemoryBlock allocate(long size) throws OutOfMemoryError {
-    assert(size > 0) : "Size must be positive, but got " + size;
-    final MemoryBlock memory = executorMemoryManager.allocate(size);
-    synchronized(allocatedNonPageMemory) {
-      allocatedNonPageMemory.add(memory);
-    }
-    return memory;
-  }
-
-  /**
-   * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
-   */
-  public void free(MemoryBlock memory) {
-    assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
-    executorMemoryManager.free(memory);
-    synchronized(allocatedNonPageMemory) {
-      final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
-      assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
-    }
-  }
-
-  /**
-   * Given a memory page and offset within that page, encode this address into a 64-bit long.
-   * This address will remain valid as long as the corresponding page has not been freed.
-   *
-   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
-   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
-   *                     this should be the value that you would pass as the base offset into an
-   *                     UNSAFE call (e.g. page.baseOffset() + something).
-   * @return an encoded page address.
-   */
-  public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
-    if (!inHeap) {
-      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
-      // encode. Due to our page size limitation, though, we can convert this into an offset that's
-      // relative to the page's base offset; this relative offset will fit in 51 bits.
-      offsetInPage -= page.getBaseOffset();
-    }
-    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
-  }
-
-  @VisibleForTesting
-  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
-    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
-    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
-  }
-
-  @VisibleForTesting
-  public static int decodePageNumber(long pagePlusOffsetAddress) {
-    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
-  }
-
-  private static long decodeOffset(long pagePlusOffsetAddress) {
-    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
-  }
-
-  /**
-   * Get the page associated with an address encoded by
-   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
-   */
-  public Object getPage(long pagePlusOffsetAddress) {
-    if (inHeap) {
-      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
-      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
-      final MemoryBlock page = pageTable[pageNumber];
-      assert (page != null);
-      assert (page.getBaseObject() != null);
-      return page.getBaseObject();
-    } else {
-      return null;
-    }
-  }
-
-  /**
-   * Get the offset associated with an address encoded by
-   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
-   */
-  public long getOffsetInPage(long pagePlusOffsetAddress) {
-    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
-    if (inHeap) {
-      return offsetInPage;
-    } else {
-      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
-      // converted the absolute address into a relative address. Here, we invert that operation:
-      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
-      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
-      final MemoryBlock page = pageTable[pageNumber];
-      assert (page != null);
-      return page.getBaseOffset() + offsetInPage;
-    }
-  }
-
-  /**
-   * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return
-   * value can be used to detect memory leaks.
-   */
-  public long cleanUpAllAllocatedMemory() {
-    long freedBytes = 0;
-    for (MemoryBlock page : pageTable) {
-      if (page != null) {
-        freedBytes += page.size();
-        freePage(page);
-      }
-    }
-
-    synchronized (allocatedNonPageMemory) {
-      final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
-      while (iter.hasNext()) {
-        final MemoryBlock memory = iter.next();
-        freedBytes += memory.size();
-        // We don't call free() here because that calls Set.remove, which would lead to a
-        // ConcurrentModificationException here.
-        executorMemoryManager.free(memory);
-        iter.remove();
-      }
-    }
-    return freedBytes;
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
deleted file mode 100644
index 06fb081..0000000
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.unsafe.memory;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class TaskMemoryManagerSuite {
-
-  @Test
-  public void leakedNonPageMemoryIsDetected() {
-    final TaskMemoryManager manager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    manager.allocate(1024);  // leak memory
-    Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory());
-  }
-
-  @Test
-  public void leakedPageMemoryIsDetected() {
-    final TaskMemoryManager manager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    manager.allocatePage(4096);  // leak memory
-    Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
-  }
-
-  @Test
-  public void encodePageNumberAndOffsetOffHeap() {
-    final TaskMemoryManager manager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
-    final MemoryBlock dataPage = manager.allocatePage(256);
-    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
-    // encode. This test exercises that corner-case:
-    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
-    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
-    Assert.assertEquals(null, manager.getPage(encodedAddress));
-    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
-  }
-
-  @Test
-  public void encodePageNumberAndOffsetOnHeap() {
-    final TaskMemoryManager manager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock dataPage = manager.allocatePage(256);
-    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
-    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
-    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
-  }
-
-}




[3/3] spark git commit: [SPARK-10984] Simplify *MemoryManager class structure

Posted by jo...@apache.org.
[SPARK-10984] Simplify *MemoryManager class structure

This patch refactors the MemoryManager class structure. After #9000, Spark had the following classes:

- MemoryManager
- StaticMemoryManager
- ExecutorMemoryManager
- TaskMemoryManager
- ShuffleMemoryManager

This is fairly confusing. To simplify things, this patch consolidates several of these classes:

- ShuffleMemoryManager and ExecutorMemoryManager were merged into MemoryManager.
- TaskMemoryManager is moved into Spark Core.
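
A rough sketch of the per-task API that results (method names as in the new TaskMemoryManager below; the `memoryManager` and `taskAttemptId` values are assumed to be supplied by the task's environment):

    // One TaskMemoryManager per task attempt, layered on the shared MemoryManager.
    TaskMemoryManager tmm = new TaskMemoryManager(memoryManager, taskAttemptId);

    // Bookkeeping-only reservation and release of execution memory:
    long granted = tmm.acquireExecutionMemory(1024 * 1024);  // may be < requested
    tmm.releaseExecutionMemory(granted);

    // Page allocation now reserves execution memory internally and returns
    // null (instead of throwing) when the reservation falls short:
    MemoryBlock page = tmm.allocatePage(tmm.pageSizeBytes());
    if (page != null) {
      tmm.freePage(page);
    }

    // At task completion, a non-zero result indicates a leak:
    long leakedBytes = tmm.cleanUpAllAllocatedMemory();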

**Key changes and tasks**:

- [x] Merge ExecutorMemoryManager into MemoryManager.
  - [x] Move pooling logic into Allocator.
- [x] Move TaskMemoryManager from `spark-unsafe` to `spark-core`.
- [x] Refactor the existing Tungsten TaskMemoryManager interactions so that Tungsten code uses only this class and not both it and ShuffleMemoryManager.
- [x] Refactor non-Tungsten code to use the TaskMemoryManager instead of ShuffleMemoryManager.
- [x] Merge ShuffleMemoryManager into MemoryManager.
  - [x] Move code
  - [x] ~~Simplify 1/n calculation.~~ **Will defer to a follow-up, since this needs more work.**
- [x] Port ShuffleMemoryManagerSuite tests.
- [x] Move classes from `unsafe` package to `memory` package.
- [ ] Figure out how to handle the hacky use of the memory managers in HashedRelation's broadcast variable construction.
- [x] Test porting and cleanup: several tests relied on mock functionality (such as `TestShuffleMemoryManager.markAsOutOfMemory`) which has been changed or broken during the memory manager consolidation
  - [x] AbstractBytesToBytesMapSuite
  - [x] UnsafeExternalSorterSuite
  - [x] UnsafeFixedWidthAggregationMapSuite
  - [x] UnsafeKVExternalSorterSuite

**Compatibility notes**:

- This patch introduces breaking changes in `ExternalAppendOnlyMap`, which is marked as `DeveloperApi` (likely for legacy reasons): this class can no longer be used outside of a task.

Author: Josh Rosen <jo...@databricks.com>

Closes #9127 from JoshRosen/SPARK-10984.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/85e654c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/85e654c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/85e654c5

Branch: refs/heads/master
Commit: 85e654c5ec87e666a8845bfd77185c1ea57b268a
Parents: 63accc7
Author: Josh Rosen <jo...@databricks.com>
Authored: Sun Oct 25 21:19:52 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Sun Oct 25 21:19:52 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/memory/TaskMemoryManager.java  | 283 ++++++++++++++++
 .../spark/shuffle/sort/PackedRecordPointer.java |   4 +-
 .../shuffle/sort/ShuffleExternalSorter.java     |  57 ++--
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |   7 +-
 .../spark/unsafe/map/BytesToBytesMap.java       |  36 +-
 .../unsafe/sort/RecordPointerAndKeyPrefix.java  |   4 +-
 .../unsafe/sort/UnsafeExternalSorter.java       |  51 +--
 .../unsafe/sort/UnsafeInMemorySorter.java       |   2 +-
 .../main/scala/org/apache/spark/SparkEnv.scala  |  23 +-
 .../scala/org/apache/spark/TaskContext.scala    |   2 +-
 .../org/apache/spark/TaskContextImpl.scala      |   2 +-
 .../org/apache/spark/executor/Executor.scala    |   4 +-
 .../org/apache/spark/memory/MemoryManager.scala | 197 ++++++++++-
 .../spark/memory/StaticMemoryManager.scala      |  12 +-
 .../spark/memory/UnifiedMemoryManager.scala     |  12 +-
 .../scala/org/apache/spark/scheduler/Task.scala |   6 +-
 .../spark/shuffle/BlockStoreShuffleReader.scala |   5 +-
 .../spark/shuffle/ShuffleMemoryManager.scala    | 209 ------------
 .../spark/shuffle/sort/SortShuffleManager.scala |   1 -
 .../spark/shuffle/sort/SortShuffleWriter.scala  |   6 +-
 .../util/collection/ExternalAppendOnlyMap.scala |  49 ++-
 .../spark/util/collection/ExternalSorter.scala  |   8 +-
 .../spark/util/collection/Spillable.scala       |  16 +-
 .../spark/memory/TaskMemoryManagerSuite.java    |  59 ++++
 .../shuffle/sort/PackedRecordPointerSuite.java  |  12 +-
 .../sort/ShuffleInMemorySorterSuite.java        |   9 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  |  53 ++-
 .../map/AbstractBytesToBytesMapSuite.java       | 108 ++----
 .../unsafe/map/BytesToBytesMapOffHeapSuite.java |   7 +-
 .../unsafe/map/BytesToBytesMapOnHeapSuite.java  |   7 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  34 +-
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  |  13 +-
 .../scala/org/apache/spark/FailureSuite.scala   |   4 +-
 .../memory/GrantEverythingMemoryManager.scala   |  54 +++
 .../spark/memory/MemoryManagerSuite.scala       | 134 ++++++++
 .../spark/memory/MemoryTestingUtils.scala       |  37 +++
 .../spark/memory/StaticMemoryManagerSuite.scala |  24 +-
 .../memory/UnifiedMemoryManagerSuite.scala      |  26 +-
 .../shuffle/ShuffleMemoryManagerSuite.scala     | 326 -------------------
 .../storage/BlockManagerReplicationSuite.scala  |   4 +-
 .../spark/storage/BlockManagerSuite.scala       |   8 +-
 .../collection/ExternalAppendOnlyMapSuite.scala |  60 ++--
 .../util/collection/ExternalSorterSuite.scala   |  48 ++-
 .../sql/execution/UnsafeExternalRowSorter.java  |   1 -
 .../UnsafeFixedWidthAggregationMap.java         |  12 +-
 .../sql/execution/UnsafeKVExternalSorter.java   |  22 +-
 .../aggregate/TungstenAggregationIterator.scala |   9 +-
 .../execution/datasources/WriterContainer.scala |   3 +-
 .../sql/execution/joins/HashedRelation.scala    |  21 +-
 .../org/apache/spark/sql/execution/sort.scala   |   5 +-
 .../execution/TestShuffleMemoryManager.scala    |  75 -----
 .../UnsafeFixedWidthAggregationMapSuite.scala   |  54 ++-
 .../execution/UnsafeKVExternalSorterSuite.scala |  19 +-
 .../execution/UnsafeRowSerializerSuite.scala    |  10 +-
 .../TungstenAggregationIteratorSuite.scala      |   4 +-
 .../streaming/ReceivedBlockHandlerSuite.scala   |   2 +-
 .../unsafe/memory/ExecutorMemoryManager.java    | 111 -------
 .../unsafe/memory/HeapMemoryAllocator.java      |  51 ++-
 .../apache/spark/unsafe/memory/MemoryBlock.java |   5 +-
 .../spark/unsafe/memory/TaskMemoryManager.java  | 286 ----------------
 .../unsafe/memory/TaskMemoryManagerSuite.java   |  64 ----
 61 files changed, 1205 insertions(+), 1572 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
new file mode 100644
index 0000000..7b31c90
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import java.util.*;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * Manages the memory allocated by an individual task.
+ * <p>
+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs.
+ * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is
+ * addressed by the combination of a base Object reference and a 64-bit offset within that object.
+ * This is a problem when we want to store pointers to data structures inside of other structures,
+ * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits
+ * to address memory, we can't just store the address of the base object since it's not guaranteed
+ * to remain stable as the heap gets reorganized due to GC.
+ * <p>
+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap
+ * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to
+ * store a "page number" and the lower 51 bits to store an offset within this page. These page
+ * numbers are used to index into a "page table" array inside of the MemoryManager in order to
+ * retrieve the base object.
+ * <p>
+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the
+ * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is
+ * approximately 140 terabytes of memory.
+ */
+public class TaskMemoryManager {
+
+  private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
+
+  /** The number of bits used to address the page table. */
+  private static final int PAGE_NUMBER_BITS = 13;
+
+  /** The number of bits used to encode offsets in data pages. */
+  @VisibleForTesting
+  static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;  // 51
+
+  /** The number of entries in the page table. */
+  private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+  /**
+   * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
+   * (1L &lt;&lt; OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
+   * size is limited by the maximum amount of data that can be stored in a  long[] array, which is
+   * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes.
+   */
+  public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
+
+  /** Bit mask for the lower 51 bits of a long. */
+  private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
+
+  /** Bit mask for the upper 13 bits of a long */
+  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+  /**
+   * Similar to an operating system's page table, this array maps page numbers into base object
+   * pointers, allowing us to translate between the hashtable's internal 64-bit address
+   * representation and the baseObject+offset representation which we use to support both in- and
+   * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`.
+   * When using an in-heap allocator, the entries in this map will point to pages' base objects.
+   * Entries are added to this map as new data pages are allocated.
+   */
+  private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
+
+  /**
+   * Bitmap for tracking free pages.
+   */
+  private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+
+  private final MemoryManager memoryManager;
+
+  private final long taskAttemptId;
+
+  /**
+   * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
+   * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
+   * this extra layer of indirection / abstraction hopefully shouldn't be too expensive.
+   */
+  private final boolean inHeap;
+
+  /**
+   * Construct a new TaskMemoryManager.
+   */
+  public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
+    this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
+    this.memoryManager = memoryManager;
+    this.taskAttemptId = taskAttemptId;
+  }
+
+  /**
+   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+   * @return number of bytes successfully granted (<= N).
+   */
+  public long acquireExecutionMemory(long size) {
+    return memoryManager.acquireExecutionMemory(size, taskAttemptId);
+  }
+
+  /**
+   * Release N bytes of execution memory.
+   */
+  public void releaseExecutionMemory(long size) {
+    memoryManager.releaseExecutionMemory(size, taskAttemptId);
+  }
+
+  public long pageSizeBytes() {
+    return memoryManager.pageSizeBytes();
+  }
+
+  /**
+   * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
+   * intended for allocating large blocks of Tungsten memory that will be shared between operators.
+   *
+   * Returns `null` if there was not enough memory to allocate the page.
+   */
+  public MemoryBlock allocatePage(long size) {
+    if (size > MAXIMUM_PAGE_SIZE_BYTES) {
+      throw new IllegalArgumentException(
+        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
+    }
+
+    final int pageNumber;
+    synchronized (this) {
+      pageNumber = allocatedPages.nextClearBit(0);
+      if (pageNumber >= PAGE_TABLE_SIZE) {
+        throw new IllegalStateException(
+          "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
+      }
+      allocatedPages.set(pageNumber);
+    }
+    final long acquiredExecutionMemory = acquireExecutionMemory(size);
+    if (acquiredExecutionMemory != size) {
+      releaseExecutionMemory(acquiredExecutionMemory);
+      synchronized (this) {
+        allocatedPages.clear(pageNumber);
+      }
+      return null;
+    }
+    final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
+    page.pageNumber = pageNumber;
+    pageTable[pageNumber] = page;
+    if (logger.isTraceEnabled()) {
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
+    }
+    return page;
+  }
+
+  /**
+   * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+   */
+  public void freePage(MemoryBlock page) {
+    assert (page.pageNumber != -1) :
+      "Called freePage() on memory that wasn't allocated with allocatePage()";
+    assert(allocatedPages.get(page.pageNumber));
+    pageTable[page.pageNumber] = null;
+    synchronized (this) {
+      allocatedPages.clear(page.pageNumber);
+    }
+    if (logger.isTraceEnabled()) {
+      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+    }
+    long pageSize = page.size();
+    memoryManager.tungstenMemoryAllocator().free(page);
+    releaseExecutionMemory(pageSize);
+  }
+
+  /**
+   * Given a memory page and offset within that page, encode this address into a 64-bit long.
+   * This address will remain valid as long as the corresponding page has not been freed.
+   *
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
+   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+   *                     this should be the value that you would pass as the base offset into an
+   *                     UNSAFE call (e.g. page.baseOffset() + something).
+   * @return an encoded page address.
+   */
+  public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
+    if (!inHeap) {
+      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+      // encode. Due to our page size limitation, though, we can convert this into an offset that's
+      // relative to the page's base offset; this relative offset will fit in 51 bits.
+      offsetInPage -= page.getBaseOffset();
+    }
+    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+  }
+
+  @VisibleForTesting
+  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+  }
+
+  @VisibleForTesting
+  public static int decodePageNumber(long pagePlusOffsetAddress) {
+    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+  }
+
+  private static long decodeOffset(long pagePlusOffsetAddress) {
+    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+  }
+
+  /**
+   * Get the page associated with an address encoded by
+   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+   */
+  public Object getPage(long pagePlusOffsetAddress) {
+    if (inHeap) {
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+      final MemoryBlock page = pageTable[pageNumber];
+      assert (page != null);
+      assert (page.getBaseObject() != null);
+      return page.getBaseObject();
+    } else {
+      return null;
+    }
+  }
+
+  /**
+   * Get the offset associated with an address encoded by
+   * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+   */
+  public long getOffsetInPage(long pagePlusOffsetAddress) {
+    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
+    if (inHeap) {
+      return offsetInPage;
+    } else {
+      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+      // converted the absolute address into a relative address. Here, we invert that operation:
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+      final MemoryBlock page = pageTable[pageNumber];
+      assert (page != null);
+      return page.getBaseOffset() + offsetInPage;
+    }
+  }
+
+  /**
+   * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return
+   * value can be used to detect memory leaks.
+   */
+  public long cleanUpAllAllocatedMemory() {
+    long freedBytes = 0;
+    for (MemoryBlock page : pageTable) {
+      if (page != null) {
+        freedBytes += page.size();
+        freePage(page);
+      }
+    }
+
+    freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
+
+    return freedBytes;
+  }
+
+  /**
+   * Returns the memory consumption, in bytes, for the current task
+   */
+  public long getMemoryConsumptionForThisTask() {
+    return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId);
+  }
+}
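
The 13-bit page number / 51-bit offset encoding above can be sanity-checked with the @VisibleForTesting static helpers (a sketch with arbitrary values; decodeOffset is private, so the lower 51 bits are masked by hand):

    long addr = TaskMemoryManager.encodePageNumberAndOffset(5, 64);
    int pageNumber = TaskMemoryManager.decodePageNumber(addr);  // 5
    long offsetInPage = addr & 0x7FFFFFFFFFFFFL;                // 64 (lower 51 bits)
    // 13 bits of page number give 1 << 13 = 8192 addressable pages.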

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index c117119..f8f2b22 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.sort;
 
+import org.apache.spark.memory.TaskMemoryManager;
+
 /**
  * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
  * <p>
@@ -26,7 +28,7 @@ package org.apache.spark.shuffle.sort;
  * </pre>
 * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
  * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
- * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * 13-bit page numbers assigned by {@link TaskMemoryManager}), this
  * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
  * <p>
  * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
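
The layout described in that Javadoc (24-bit partition number, 13-bit page number, 27-bit offset within the page) can be illustrated with plain bit arithmetic. This is an illustrative sketch only; the class's actual packing code is not part of this hunk:

    // Hypothetical packing following the documented layout:
    // [24 bit partition number][13 bit page number][27 bit offset in page]
    long partitionId = 42;        // < 2^24
    long pageNumber = 7;          // < 2^13
    long offsetInPage = 1 << 20;  // < 2^27 (128 megabyte maximum page)
    long packed = (partitionId << 40) | (pageNumber << 27) | offsetInPage;

    long partition = (packed >>> 40) & ((1L << 24) - 1);  // 42
    long page      = (packed >>> 27) & ((1L << 13) - 1);  // 7
    long offset    = packed & ((1L << 27) - 1);           // 1048576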

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 85fdaa8..f43236f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -33,14 +33,13 @@ import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 /**
@@ -72,7 +71,6 @@ final class ShuffleExternalSorter {
   @VisibleForTesting
   final int maxRecordSizeBytes;
   private final TaskMemoryManager taskMemoryManager;
-  private final ShuffleMemoryManager shuffleMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private final ShuffleWriteMetrics writeMetrics;
@@ -105,7 +103,6 @@ final class ShuffleExternalSorter {
 
   public ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       int initialSize,
@@ -113,7 +110,6 @@ final class ShuffleExternalSorter {
       SparkConf conf,
       ShuffleWriteMetrics writeMetrics) throws IOException {
     this.taskMemoryManager = memoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.initialSize = initialSize;
@@ -124,7 +120,7 @@ final class ShuffleExternalSorter {
     this.numElementsForSpillThreshold =
       conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
     this.pageSizeBytes = (int) Math.min(
-      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
+      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
     this.maxRecordSizeBytes = pageSizeBytes - 4;
     this.writeMetrics = writeMetrics;
     initializeForWriting();
@@ -140,9 +136,9 @@ final class ShuffleExternalSorter {
   private void initializeForWriting() throws IOException {
     // TODO: move this sizing calculation logic into a static method of sorter:
     final long memoryRequested = initialSize * 8L;
-    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+    final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
     if (memoryAcquired != memoryRequested) {
-      shuffleMemoryManager.release(memoryAcquired);
+      taskMemoryManager.releaseExecutionMemory(memoryAcquired);
       throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
     }
 
@@ -272,6 +268,7 @@ final class ShuffleExternalSorter {
    */
   @VisibleForTesting
   void spill() throws IOException {
+    assert(inMemSorter != null);
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -281,7 +278,7 @@ final class ShuffleExternalSorter {
     writeSortedFile(false);
     final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
     inMemSorter = null;
-    shuffleMemoryManager.release(inMemSorterMemoryUsage);
+    taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage);
     final long spillSize = freeMemory();
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
 
@@ -316,9 +313,13 @@ final class ShuffleExternalSorter {
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
       taskMemoryManager.freePage(block);
-      shuffleMemoryManager.release(block.size());
       memoryFreed += block.size();
     }
+    if (inMemSorter != null) {
+      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+      inMemSorter = null;
+      taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
+    }
     allocatedPages.clear();
     currentPage = null;
     currentPagePosition = -1;
@@ -337,8 +338,9 @@ final class ShuffleExternalSorter {
       }
     }
     if (inMemSorter != null) {
-      shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
+      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
       inMemSorter = null;
+      taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
     }
   }
 
@@ -353,21 +355,20 @@ final class ShuffleExternalSorter {
       logger.debug("Attempting to expand sort pointer array");
       final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
       final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
-      final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+      final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray);
       if (memoryAcquired < memoryToGrowPointerArray) {
-        shuffleMemoryManager.release(memoryAcquired);
+        taskMemoryManager.releaseExecutionMemory(memoryAcquired);
         spill();
       } else {
         inMemSorter.expandPointerArray();
-        shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+        taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage);
       }
     }
   }
-  
+
   /**
    * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
+   * memory from the memory manager and spill if the requested memory can not be obtained.
    *
    * @param requiredSpace the required space in the data page, in bytes, including space for storing
    *                      the record size. This must be less than or equal to the page size (records
@@ -386,17 +387,14 @@ final class ShuffleExternalSorter {
         throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
           pageSizeBytes + ")");
       } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryAcquired < pageSizeBytes) {
-          shuffleMemoryManager.release(memoryAcquired);
+        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+        if (currentPage == null) {
           spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
-            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+          currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+          if (currentPage == null) {
             throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
           }
         }
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
         currentPagePosition = currentPage.getBaseOffset();
         freeSpaceInCurrentPage = pageSizeBytes;
         allocatedPages.add(currentPage);
@@ -430,17 +428,14 @@ final class ShuffleExternalSorter {
       long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
       // The record is larger than the page size, so allocate a special overflow page just to hold
       // that record.
-      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-      if (memoryGranted != overflowPageSize) {
-        shuffleMemoryManager.release(memoryGranted);
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      if (overflowPage == null) {
         spill();
-        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-        if (memoryGrantedAfterSpill != overflowPageSize) {
-          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+        if (overflowPage == null) {
           throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
         }
       }
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
       allocatedPages.add(overflowPage);
       dataPage = overflowPage;
       dataPagePosition = overflowPage.getBaseOffset();
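
The new allocation pattern above replaces the old two-step tryToAcquire/allocatePage dance with a single call whose null return signals memory pressure. Condensed, the retry logic is (assuming `taskMemoryManager`, `pageSizeBytes`, and a `spill()` method in scope, as in the sorter):

    MemoryBlock page = taskMemoryManager.allocatePage(pageSizeBytes);
    if (page == null) {
      // Not enough execution memory: spill this sorter's own data, then retry once.
      spill();
      page = taskMemoryManager.allocatePage(pageSizeBytes);
      if (page == null) {
        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
      }
    }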

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index e8f050c..f6c5c94 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -49,12 +49,11 @@ import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -69,7 +68,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final BlockManager blockManager;
   private final IndexShuffleBlockResolver shuffleBlockResolver;
   private final TaskMemoryManager memoryManager;
-  private final ShuffleMemoryManager shuffleMemoryManager;
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
   private final ShuffleWriteMetrics writeMetrics;
@@ -103,7 +101,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       BlockManager blockManager,
       IndexShuffleBlockResolver shuffleBlockResolver,
       TaskMemoryManager memoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       SerializedShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
@@ -117,7 +114,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.blockManager = blockManager;
     this.shuffleBlockResolver = shuffleBlockResolver;
     this.memoryManager = memoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
     this.mapId = mapId;
     final ShuffleDependency<K, V, V> dep = handle.dependency();
     this.shuffleId = dep.shuffleId();
@@ -197,7 +193,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     assert (sorter == null);
     sorter = new ShuffleExternalSorter(
       memoryManager,
-      shuffleMemoryManager,
       blockManager,
       taskContext,
       INITIAL_SORT_BUFFER_SIZE,

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index b24eed3..f035bda 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -26,7 +26,6 @@ import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
@@ -34,7 +33,7 @@ import org.apache.spark.unsafe.bitset.BitSet;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 /**
  * An append-only hash map where keys and values are contiguous regions of bytes.
@@ -70,8 +69,6 @@ public final class BytesToBytesMap {
 
   private final TaskMemoryManager taskMemoryManager;
 
-  private final ShuffleMemoryManager shuffleMemoryManager;
-
   /**
    * A linked list for tracking all allocated data pages so that we can free all of our memory.
    */
@@ -169,13 +166,11 @@ public final class BytesToBytesMap {
 
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       double loadFactor,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
     this.taskMemoryManager = taskMemoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
     this.loadFactor = loadFactor;
     this.loc = new Location();
     this.pageSizeBytes = pageSizeBytes;
@@ -201,21 +196,18 @@ public final class BytesToBytesMap {
 
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       long pageSizeBytes) {
-    this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+    this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
   }
 
   public BytesToBytesMap(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       int initialCapacity,
       long pageSizeBytes,
       boolean enablePerfMetrics) {
     this(
       taskMemoryManager,
-      shuffleMemoryManager,
       initialCapacity,
       0.70,
       pageSizeBytes,
@@ -260,7 +252,6 @@ public final class BytesToBytesMap {
       if (destructive && currentPage != null) {
         dataPagesIterator.remove();
         this.bmap.taskMemoryManager.freePage(currentPage);
-        this.bmap.shuffleMemoryManager.release(currentPage.size());
       }
       currentPage = dataPagesIterator.next();
       pageBaseObject = currentPage.getBaseObject();
@@ -572,14 +563,12 @@ public final class BytesToBytesMap {
       if (useOverflowPage) {
         // The record is larger than the page size, so allocate a special overflow page just to hold
         // that record.
-        final long memoryRequested = requiredSize + 8;
-        final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
-        if (memoryGranted != memoryRequested) {
-          shuffleMemoryManager.release(memoryGranted);
-          logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+        final long overflowPageSize = requiredSize + 8;
+        MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+        if (overflowPage == null) {
+          logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
           return false;
         }
-        MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
         dataPages.add(overflowPage);
         dataPage = overflowPage;
         dataPageBaseObject = overflowPage.getBaseObject();
@@ -655,17 +644,15 @@ public final class BytesToBytesMap {
   }
 
   /**
-   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   * Acquire a new page from the memory manager.
    * @return whether there is enough space to allocate the new page.
    */
   private boolean acquireNewPage() {
-    final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-    if (memoryGranted != pageSizeBytes) {
-      shuffleMemoryManager.release(memoryGranted);
+    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    if (newPage == null) {
       logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
       return false;
     }
-    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
     dataPages.add(newPage);
     pageCursor = 0;
     currentDataPage = newPage;
@@ -705,7 +692,6 @@ public final class BytesToBytesMap {
       MemoryBlock dataPage = dataPagesIterator.next();
       dataPagesIterator.remove();
       taskMemoryManager.freePage(dataPage);
-      shuffleMemoryManager.release(dataPage.size());
     }
     assert(dataPages.isEmpty());
   }
@@ -714,10 +700,6 @@ public final class BytesToBytesMap {
     return taskMemoryManager;
   }
 
-  public ShuffleMemoryManager getShuffleMemoryManager() {
-    return shuffleMemoryManager;
-  }
-
   public long getPageSizeBytes() {
     return pageSizeBytes;
   }
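
BytesToBytesMap follows the same null-means-failure contract, but since the map has no spill path of its own it simply propagates the failure: acquireNewPage() returns false, and callers report that the insertion did not happen. A condensed caller-side sketch (the surrounding insertion method is hypothetical):

    // Inside an insertion path: give up gracefully when no page can be acquired.
    if (!acquireNewPage()) {
      return false;  // the operator owning the map decides whether to spill or fail
    }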

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
index 0c4ebde..dbf6770 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -17,9 +17,11 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import org.apache.spark.memory.TaskMemoryManager;
+
 final class RecordPointerAndKeyPrefix {
   /**
-   * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+   * A pointer to a record; see {@link TaskMemoryManager} for a
    * description of how these addresses are encoded.
    */
   public long recordPointer;

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 0a311d2..e317ea3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -32,12 +32,11 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 /**
@@ -52,7 +51,6 @@ public final class UnsafeExternalSorter {
   private final RecordComparator recordComparator;
   private final int initialSize;
   private final TaskMemoryManager taskMemoryManager;
-  private final ShuffleMemoryManager shuffleMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
   private ShuffleWriteMetrics writeMetrics;
@@ -82,7 +80,6 @@ public final class UnsafeExternalSorter {
 
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
@@ -90,26 +87,24 @@ public final class UnsafeExternalSorter {
       int initialSize,
       long pageSizeBytes,
       UnsafeInMemorySorter inMemorySorter) throws IOException {
-    return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+    return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
   }
 
   public static UnsafeExternalSorter create(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes) throws IOException {
-    return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+    return new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
   }
 
   private UnsafeExternalSorter(
       TaskMemoryManager taskMemoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
       BlockManager blockManager,
       TaskContext taskContext,
       RecordComparator recordComparator,
@@ -118,7 +113,6 @@ public final class UnsafeExternalSorter {
       long pageSizeBytes,
       @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
     this.taskMemoryManager = taskMemoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.recordComparator = recordComparator;
@@ -261,7 +255,6 @@ public final class UnsafeExternalSorter {
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
       taskMemoryManager.freePage(block);
-      shuffleMemoryManager.release(block.size());
       memoryFreed += block.size();
     }
     // TODO: track in-memory sorter memory usage (SPARK-10474)
@@ -309,8 +302,7 @@ public final class UnsafeExternalSorter {
 
   /**
    * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
+   * memory from the memory manager and spill if the requested memory cannot be obtained.
    *
    * @param requiredSpace the required space in the data page, in bytes, including space for storing
    *                      the record size. This must be less than or equal to the page size (records
@@ -335,23 +327,20 @@ public final class UnsafeExternalSorter {
   }
 
   /**
-   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   * Acquire a new page from the memory manager.
    *
    * If there is not enough space to allocate the new page, spill all existing ones
    * and try again. If there is still not enough space, report error to the caller.
    */
   private void acquireNewPage() throws IOException {
-    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-    if (memoryAcquired < pageSizeBytes) {
-      shuffleMemoryManager.release(memoryAcquired);
+    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    if (currentPage == null) {
       spill();
-      final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-      if (memoryAcquiredAfterSpilling != pageSizeBytes) {
-        shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+      currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+      if (currentPage == null) {
         throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
       }
     }
-    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
     currentPagePosition = currentPage.getBaseOffset();
     freeSpaceInCurrentPage = pageSizeBytes;
     allocatedPages.add(currentPage);
@@ -379,17 +368,14 @@ public final class UnsafeExternalSorter {
       long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
       // The record is larger than the page size, so allocate a special overflow page just to hold
       // that record.
-      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-      if (memoryGranted != overflowPageSize) {
-        shuffleMemoryManager.release(memoryGranted);
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      if (overflowPage == null) {
         spill();
-        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-        if (memoryGrantedAfterSpill != overflowPageSize) {
-          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+        if (overflowPage == null) {
           throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
         }
       }
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
       allocatedPages.add(overflowPage);
       dataPage = overflowPage;
       dataPagePosition = overflowPage.getBaseOffset();
@@ -441,17 +427,14 @@ public final class UnsafeExternalSorter {
       long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
       // The record is larger than the page size, so allocate a special overflow page just to hold
       // that record.
-      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-      if (memoryGranted != overflowPageSize) {
-        shuffleMemoryManager.release(memoryGranted);
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      if (overflowPage == null) {
         spill();
-        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-        if (memoryGrantedAfterSpill != overflowPageSize) {
-          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+        overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+        if (overflowPage == null) {
           throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
         }
       }
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
       allocatedPages.add(overflowPage);
       dataPage = overflowPage;
       dataPagePosition = overflowPage.getBaseOffset();
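
The allocation path above replaces the old acquire/release bookkeeping with a single
nullable allocatePage call plus a spill-and-retry. A minimal Scala sketch of that
pattern, using Option where the Java code returns null; allocatePage and spill here are
stand-ins for the real methods, not their actual signatures:

  import java.io.IOException

  def acquirePageOrSpill[P](
      allocatePage: Long => Option[P],  // stand-in for TaskMemoryManager.allocatePage
      spill: () => Unit,                // stand-in for UnsafeExternalSorter.spill
      pageSizeBytes: Long): P = {
    allocatePage(pageSizeBytes).getOrElse {
      // Not enough execution memory: spill what we hold, then retry exactly once.
      spill()
      allocatePage(pageSizeBytes).getOrElse(
        throw new IOException(s"Unable to acquire $pageSizeBytes bytes of memory"))
    }
  }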

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index f7787e1..5aad72c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -21,7 +21,7 @@ import java.util.Comparator;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.collection.Sorter;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 /**
  * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b5c35c5..398e093 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -38,9 +38,8 @@ import org.apache.spark.rpc.akka.AkkaRpcEnv
 import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
 import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
+import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
 import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
 
 /**
@@ -70,10 +69,7 @@ class SparkEnv (
     val httpFileServer: HttpFileServer,
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
-    // TODO: unify these *MemoryManager classes (SPARK-10984)
     val memoryManager: MemoryManager,
-    val shuffleMemoryManager: ShuffleMemoryManager,
-    val executorMemoryManager: ExecutorMemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
@@ -340,13 +336,11 @@ object SparkEnv extends Logging {
     val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false)
     val memoryManager: MemoryManager =
       if (useLegacyMemoryManager) {
-        new StaticMemoryManager(conf)
+        new StaticMemoryManager(conf, numUsableCores)
       } else {
-        new UnifiedMemoryManager(conf)
+        new UnifiedMemoryManager(conf, numUsableCores)
       }
 
-    val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores)
-
     val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores)
 
     val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
@@ -405,15 +399,6 @@ object SparkEnv extends Logging {
       new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
     outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
 
-    val executorMemoryManager: ExecutorMemoryManager = {
-      val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
-        MemoryAllocator.UNSAFE
-      } else {
-        MemoryAllocator.HEAP
-      }
-      new ExecutorMemoryManager(allocator)
-    }
-
     val envInstance = new SparkEnv(
       executorId,
       rpcEnv,
@@ -431,8 +416,6 @@ object SparkEnv extends Logging {
       sparkFilesDir,
       metricsSystem,
       memoryManager,
-      shuffleMemoryManager,
-      executorMemoryManager,
       outputCommitCoordinator,
       conf)
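
With the two managers unified behind MemoryManager, the only user-visible switch left in
this block is "spark.memory.useLegacyMode". A hypothetical opt-in, shown only to
illustrate the flag that SparkEnv reads; the app name is ours:

  import org.apache.spark.SparkConf

  // Hypothetical user configuration: opt back into Spark 1.5-style static
  // memory partitioning (StaticMemoryManager) instead of UnifiedMemoryManager.
  val conf = new SparkConf()
    .setAppName("legacy-memory-demo") // illustrative name only
    .set("spark.memory.useLegacyMode", "true")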
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 63cca80..af558d6 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,8 +21,8 @@ import java.io.Serializable
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.TaskCompletionListener
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 5df94c6..f0ae83a 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -20,9 +20,9 @@ package org.apache.spark
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
 
 private[spark] class TaskContextImpl(

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c3491bb..9e88d48 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,10 +29,10 @@ import scala.util.control.NonFatal
 
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
 
 /**
@@ -179,7 +179,7 @@ private[spark] class Executor(
     }
 
     override def run(): Unit = {
-      val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+      val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = env.closureSerializer.newInstance()

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 7168ac5..6c9a71c 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -17,20 +17,38 @@
 
 package org.apache.spark.memory
 
+import javax.annotation.concurrent.GuardedBy
+
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.Logging
-import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
+import com.google.common.annotations.VisibleForTesting
 
+import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
+import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.memory.MemoryAllocator
 
 /**
  * An abstract memory manager that enforces how memory is shared between execution and storage.
  *
  * In this context, execution memory refers to that used for computation in shuffles, joins,
  * sorts and aggregations, while storage memory refers to that used for caching and propagating
- * internal data across the cluster. There exists one of these per JVM.
+ * internal data across the cluster. There exists one MemoryManager per JVM.
+ *
+ * The MemoryManager abstract base class itself implements policies for sharing execution memory
+ * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of
+ * some task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
+ * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
+ * this set changes. This is all done by synchronizing access to mutable state and using wait() and
+ * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across
+ * tasks was performed by the ShuffleMemoryManager.
  */
-private[spark] abstract class MemoryManager extends Logging {
+private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging {
+
+  // -- Methods related to memory allocation policies and bookkeeping ------------------------------
 
   // The memory store used to evict cached blocks
   private var _memoryStore: MemoryStore = _
@@ -42,8 +60,10 @@ private[spark] abstract class MemoryManager extends Logging {
   }
 
   // Amount of execution/storage memory in use, accesses must be synchronized on `this`
-  protected var _executionMemoryUsed: Long = 0
-  protected var _storageMemoryUsed: Long = 0
+  @GuardedBy("this") protected var _executionMemoryUsed: Long = 0
+  @GuardedBy("this") protected var _storageMemoryUsed: Long = 0
+  // Map from taskAttemptId -> memory consumption in bytes
+  @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]()
 
   /**
    * Set the [[MemoryStore]] used by this manager to evict cached blocks.
@@ -66,15 +86,6 @@ private[spark] abstract class MemoryManager extends Logging {
   // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985)
 
   /**
-   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
-   * Blocks evicted in the process, if any, are added to `evictedBlocks`.
-   * @return number of bytes successfully granted (<= N).
-   */
-  def acquireExecutionMemory(
-      numBytes: Long,
-      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
-
-  /**
    * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
    * Blocks evicted in the process, if any, are added to `evictedBlocks`.
    * @return whether all N bytes were successfully granted.
@@ -102,9 +113,92 @@ private[spark] abstract class MemoryManager extends Logging {
   }
 
   /**
-   * Release N bytes of execution memory.
+   * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+   * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+   * @return number of bytes successfully granted (<= N).
+   */
+  @VisibleForTesting
+  private[memory] def doAcquireExecutionMemory(
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
+
+  /**
+   * Try to acquire up to `numBytes` of execution memory for the current task and return the number
+   * of bytes obtained, or 0 if none can be allocated.
+   *
+   * This call may block until there is enough free memory in some situations, to make sure each
+   * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
+   * active tasks) before it is forced to spill. This can happen if the number of tasks
+   * increases but an older task had already claimed a lot of memory.
+   *
+   * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies
+   * that control global sharing of memory between execution and storage.
    */
-  def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
+  private[memory]
+  final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized {
+    assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
+
+    // Add this task to the taskMemory map just so we can keep an accurate count of the number
+    // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
+    if (!executionMemoryForTask.contains(taskAttemptId)) {
+      executionMemoryForTask(taskAttemptId) = 0L
+      // This will later cause waiting tasks to wake up and check numTasks again
+      notifyAll()
+    }
+
+    // Once the cross-task memory allocation policy has decided to grant more memory to a task,
+    // this method is called in order to actually obtain that execution memory, potentially
+    // triggering eviction of storage memory:
+    def acquire(toGrant: Long): Long = synchronized {
+      val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+      val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks)
+      // Register evicted blocks, if any, with the active task metrics
+      Option(TaskContext.get()).foreach { tc =>
+        val metrics = tc.taskMetrics()
+        val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+        metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
+      }
+      executionMemoryForTask(taskAttemptId) += acquired
+      acquired
+    }
+
+    // Keep looping until we're either sure that we don't want to grant this request (because this
+    // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+    // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
+    // TODO: simplify this to limit each task to its own slot
+    while (true) {
+      val numActiveTasks = executionMemoryForTask.keys.size
+      val curMem = executionMemoryForTask(taskAttemptId)
+      val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum
+
+      // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
+      // don't let it be negative
+      val maxToGrant =
+        math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem))
+      // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
+      val toGrant = math.min(maxToGrant, freeMemory)
+
+      if (curMem < maxExecutionMemory / (2 * numActiveTasks)) {
+        // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+        // if we can't give it this much now, wait for other tasks to free up memory
+        // (this happens if older tasks allocated lots of memory before N grew)
+        if (
+          freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) {
+          return acquire(toGrant)
+        } else {
+          logInfo(
+            s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free")
+          wait()
+        }
+      } else {
+        return acquire(toGrant)
+      }
+    }
+    0L  // Never reached
+  }
+
+  @VisibleForTesting
+  private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
     if (numBytes > _executionMemoryUsed) {
       logWarning(s"Attempted to release $numBytes bytes of execution " +
         s"memory when we only have ${_executionMemoryUsed} bytes")
@@ -115,6 +209,36 @@ private[spark] abstract class MemoryManager extends Logging {
   }
 
   /**
+   * Release numBytes of execution memory belonging to the given task.
+   */
+  private[memory]
+  final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
+    val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
+    if (curMem < numBytes) {
+      throw new SparkException(
+        s"Internal error: release called on $numBytes bytes but task only has $curMem")
+    }
+    if (executionMemoryForTask.contains(taskAttemptId)) {
+      executionMemoryForTask(taskAttemptId) -= numBytes
+      if (executionMemoryForTask(taskAttemptId) <= 0) {
+        executionMemoryForTask.remove(taskAttemptId)
+      }
+      releaseExecutionMemory(numBytes)
+    }
+    notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed
+  }
+
+  /**
+   * Release all memory for the given task and mark it as inactive (e.g. when a task ends).
+   * @return the number of bytes freed.
+   */
+  private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized {
+    val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId)
+    releaseExecutionMemory(numBytesToFree, taskAttemptId)
+    numBytesToFree
+  }
+
+  /**
    * Release N bytes of storage memory.
    */
   def releaseStorageMemory(numBytes: Long): Unit = synchronized {
@@ -155,4 +279,43 @@ private[spark] abstract class MemoryManager extends Logging {
     _storageMemoryUsed
   }
 
+  /**
+   * Returns the execution memory consumption, in bytes, for the given task.
+   */
+  private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized {
+    executionMemoryForTask.getOrElse(taskAttemptId, 0L)
+  }
+
+  // -- Fields related to Tungsten managed memory -------------------------------------------------
+
+  /**
+   * The default page size, in bytes.
+   *
+   * If the user didn't explicitly set "spark.buffer.pageSize", we compute the default value
+   * from the number of cores available to the process and the total amount of memory, then
+   * divide by a safety factor.
+   */
+  val pageSizeBytes: Long = {
+    val minPageSize = 1L * 1024 * 1024   // 1MB
+    val maxPageSize = 64L * minPageSize  // 64MB
+    val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
+    // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
+    val safetyFactor = 16
+    val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor)
+    val default = math.min(maxPageSize, math.max(minPageSize, size))
+    conf.getSizeAsBytes("spark.buffer.pageSize", default)
+  }
+
+  /**
+   * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using
+   * sun.misc.Unsafe.
+   */
+  final val tungstenMemoryIsAllocatedInHeap: Boolean =
+    !conf.getBoolean("spark.unsafe.offHeap", false)
+
+  /**
+   * Allocates memory for use by Unsafe/Tungsten code.
+   */
+  private[memory] final val tungstenMemoryAllocator: MemoryAllocator =
+    if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE
 }
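
To make the 1/2N .. 1/N policy documented above concrete, here is a compact model of the
grant computation in acquireExecutionMemory, plus a worked page-size example. The
function and value names are ours, not Spark API:

  // How much may one task be granted right now? Mirrors the loop body above.
  def grantBound(
      maxExecutionMemory: Long,
      usage: Map[Long, Long],  // taskAttemptId -> bytes currently held
      taskAttemptId: Long,
      numBytes: Long): Long = {
    val n = usage.size
    val curMem = usage.getOrElse(taskAttemptId, 0L)
    val free = maxExecutionMemory - usage.values.sum
    // Cap at 1/N of the pool, never negative...
    val maxToGrant = math.min(numBytes, math.max(0L, maxExecutionMemory / n - curMem))
    // ...and never grant more than is actually free.
    math.min(maxToGrant, free)
  }

  // Example: a 1000-byte pool, task 1 holds 600 bytes, task 2 asks for 500.
  // The 1/N cap allows 500, free memory caps it at 400, and 400 clears task 2's
  // 1/2N floor of 250, so the request is granted immediately with 400 bytes.
  grantBound(1000L, Map(1L -> 600L, 2L -> 0L), taskAttemptId = 2L, numBytes = 500L) // 400

  // pageSizeBytes, worked: 4GB of execution memory on 8 cores with the safety
  // factor of 16 gives nextPowerOf2(4GB / 8 / 16) = 32MB, which sits inside the
  // [1MB, 64MB] clamp, so the default page size would be 32MB.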

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
index fa44f37..9c2c2e9 100644
--- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
@@ -33,14 +33,16 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
 private[spark] class StaticMemoryManager(
     conf: SparkConf,
     override val maxExecutionMemory: Long,
-    override val maxStorageMemory: Long)
-  extends MemoryManager {
+    override val maxStorageMemory: Long,
+    numCores: Int)
+  extends MemoryManager(conf, numCores) {
 
-  def this(conf: SparkConf) {
+  def this(conf: SparkConf, numCores: Int) {
     this(
       conf,
       StaticMemoryManager.getMaxExecutionMemory(conf),
-      StaticMemoryManager.getMaxStorageMemory(conf))
+      StaticMemoryManager.getMaxStorageMemory(conf),
+      numCores)
   }
 
   // Max number of bytes worth of blocks to evict when unrolling
@@ -52,7 +54,7 @@ private[spark] class StaticMemoryManager(
    * Acquire N bytes of memory for execution.
    * @return number of bytes successfully granted (<= N).
    */
-  override def acquireExecutionMemory(
+  override def doAcquireExecutionMemory(
       numBytes: Long,
       evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
     assert(numBytes >= 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index 5bf78d5..a309303 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -42,10 +42,14 @@ import org.apache.spark.storage.{BlockStatus, BlockId}
  * up most of the storage space, in which case the new blocks will be evicted immediately
  * according to their respective storage levels.
  */
-private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager {
+private[spark] class UnifiedMemoryManager(
+    conf: SparkConf,
+    maxMemory: Long,
+    numCores: Int)
+  extends MemoryManager(conf, numCores) {
 
-  def this(conf: SparkConf) {
-    this(conf, UnifiedMemoryManager.getMaxMemory(conf))
+  def this(conf: SparkConf, numCores: Int) {
+    this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores)
   }
 
   /**
@@ -91,7 +95,7 @@ private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) exte
    * Blocks evicted in the process, if any, are added to `evictedBlocks`.
    * @return number of bytes successfully granted (<= N).
    */
-  override def acquireExecutionMemory(
+  private[memory] override def doAcquireExecutionMemory(
       numBytes: Long,
       evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
     assert(numBytes >= 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 9edf9f0..4fb32ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,8 +25,8 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.ByteBufferInputStream
 import org.apache.spark.util.Utils
 
@@ -90,10 +90,6 @@ private[spark] abstract class Task[T](
       context.markTaskCompleted()
       try {
         Utils.tryLogNonFatalError {
-          // Release memory used by this thread for shuffles
-          SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
-        }
-        Utils.tryLogNonFatalError {
           // Release memory used by this thread for unrolling blocks
           SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 7c3e2b5..b0abda4 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -98,13 +98,14 @@ private[spark] class BlockStoreShuffleReader[K, C](
       case Some(keyOrd: Ordering[K]) =>
         // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
         // the ExternalSorter won't spill to disk.
-        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
+        val sorter =
+          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
         sorter.insertAll(aggregatedIter)
         context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
         context.internalMetricsToAccumulators(
           InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
-        sorter.iterator
+        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
       case None =>
         aggregatedIter
     }
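
CompletionIterator, used here to tie sorter.stop() to iterator exhaustion, runs its
completion callback once when the wrapped iterator is drained. A minimal usage sketch;
note that CompletionIterator is private[spark], so this only compiles inside Spark's own
packages:

  import org.apache.spark.util.CompletionIterator

  var cleaned = false
  val it = CompletionIterator[Int, Iterator[Int]](
    Iterator(1, 2, 3),
    { cleaned = true })  // by-name callback, runs after the last element
  it.foreach(_ => ())    // drain the iterator; afterwards cleaned == true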

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
deleted file mode 100644
index 9bd18da..0000000
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ /dev/null
@@ -1,209 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle
-
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import com.google.common.annotations.VisibleForTesting
-
-import org.apache.spark._
-import org.apache.spark.memory.{StaticMemoryManager, MemoryManager}
-import org.apache.spark.storage.{BlockId, BlockStatus}
-import org.apache.spark.unsafe.array.ByteArrayMethods
-
-/**
- * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
- * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
- * from this pool and release it as it spills data out. When a task ends, all its memory will be
- * released by the Executor.
- *
- * This class tries to ensure that each task gets a reasonable share of memory, instead of some
- * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory
- * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
- * this set changes. This is all done by synchronizing access to `memoryManager` to mutate state
- * and using wait() and notifyAll() to signal changes.
- *
- * Use `ShuffleMemoryManager.create()` factory method to create a new instance.
- *
- * @param memoryManager the interface through which this manager acquires execution memory
- * @param pageSizeBytes number of bytes for each page, by default.
- */
-private[spark]
-class ShuffleMemoryManager protected (
-    memoryManager: MemoryManager,
-    val pageSizeBytes: Long)
-  extends Logging {
-
-  private val taskMemory = new mutable.HashMap[Long, Long]()  // taskAttemptId -> memory bytes
-
-  private def currentTaskAttemptId(): Long = {
-    // In case this is called on the driver, return an invalid task attempt id.
-    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
-  }
-
-  /**
-   * Try to acquire up to numBytes memory for the current task, and return the number of bytes
-   * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
-   * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
-   * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
-   * happen if the number of tasks increases but an older task had a lot of memory already.
-   */
-  def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized {
-    val taskAttemptId = currentTaskAttemptId()
-    assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
-
-    // Add this task to the taskMemory map just so we can keep an accurate count of the number
-    // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
-    if (!taskMemory.contains(taskAttemptId)) {
-      taskMemory(taskAttemptId) = 0L
-      // This will later cause waiting tasks to wake up and check numTasks again
-      memoryManager.notifyAll()
-    }
-
-    // Keep looping until we're either sure that we don't want to grant this request (because this
-    // task would have more than 1 / numActiveTasks of the memory) or we have enough free
-    // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
-    // TODO: simplify this to limit each task to its own slot
-    while (true) {
-      val numActiveTasks = taskMemory.keys.size
-      val curMem = taskMemory(taskAttemptId)
-      val maxMemory = memoryManager.maxExecutionMemory
-      val freeMemory = maxMemory - taskMemory.values.sum
-
-      // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
-      // don't let it be negative
-      val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
-      // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
-      val toGrant = math.min(maxToGrant, freeMemory)
-
-      if (curMem < maxMemory / (2 * numActiveTasks)) {
-        // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
-        // if we can't give it this much now, wait for other tasks to free up memory
-        // (this happens if older tasks allocated lots of memory before N grew)
-        if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
-          return acquire(toGrant)
-        } else {
-          logInfo(
-            s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
-          memoryManager.wait()
-        }
-      } else {
-        return acquire(toGrant)
-      }
-    }
-    0L  // Never reached
-  }
-
-  /**
-   * Acquire N bytes of execution memory from the memory manager for the current task.
-   * @return number of bytes actually acquired (<= N).
-   */
-  private def acquire(numBytes: Long): Long = memoryManager.synchronized {
-    val taskAttemptId = currentTaskAttemptId()
-    val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
-    val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks)
-    // Register evicted blocks, if any, with the active task metrics
-    // TODO: just do this in `acquireExecutionMemory` (SPARK-10985)
-    Option(TaskContext.get()).foreach { tc =>
-      val metrics = tc.taskMetrics()
-      val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
-      metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
-    }
-    taskMemory(taskAttemptId) += acquired
-    acquired
-  }
-
-  /** Release numBytes bytes for the current task. */
-  def release(numBytes: Long): Unit = memoryManager.synchronized {
-    val taskAttemptId = currentTaskAttemptId()
-    val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
-    if (curMem < numBytes) {
-      throw new SparkException(
-        s"Internal error: release called on $numBytes bytes but task only has $curMem")
-    }
-    if (taskMemory.contains(taskAttemptId)) {
-      taskMemory(taskAttemptId) -= numBytes
-      memoryManager.releaseExecutionMemory(numBytes)
-    }
-    memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
-  }
-
-  /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
-  def releaseMemoryForThisTask(): Unit = memoryManager.synchronized {
-    val taskAttemptId = currentTaskAttemptId()
-    taskMemory.remove(taskAttemptId).foreach { numBytes =>
-      memoryManager.releaseExecutionMemory(numBytes)
-    }
-    memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
-  }
-
-  /** Returns the memory consumption, in bytes, for the current task */
-  def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized {
-    val taskAttemptId = currentTaskAttemptId()
-    taskMemory.getOrElse(taskAttemptId, 0L)
-  }
-}
-
-
-private[spark] object ShuffleMemoryManager {
-
-  def create(
-      conf: SparkConf,
-      memoryManager: MemoryManager,
-      numCores: Int): ShuffleMemoryManager = {
-    val maxMemory = memoryManager.maxExecutionMemory
-    val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores)
-    new ShuffleMemoryManager(memoryManager, pageSize)
-  }
-
-  /**
-   * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size.
-   */
-  def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = {
-    val conf = new SparkConf
-    val memoryManager = new StaticMemoryManager(
-      conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue)
-    new ShuffleMemoryManager(memoryManager, pageSizeBytes)
-  }
-
-  @VisibleForTesting
-  def createForTesting(maxMemory: Long): ShuffleMemoryManager = {
-    create(maxMemory, 4 * 1024 * 1024)
-  }
-
-  /**
-   * Sets the page size, in bytes.
-   *
-   * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value
-   * by looking at the number of cores available to the process, and the total amount of memory,
-   * and then divide it by a factor of safety.
-   */
-  private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = {
-    val minPageSize = 1L * 1024 * 1024   // 1MB
-    val maxPageSize = 64L * minPageSize  // 64MB
-    val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
-    // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
-    val safetyFactor = 16
-    val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor)
-    val default = math.min(maxPageSize, math.max(minPageSize, size))
-    conf.getSizeAsBytes("spark.buffer.pageSize", default)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 1105167..66b6bbc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -133,7 +133,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
           context.taskMemoryManager(),
-          env.shuffleMemoryManager,
           unsafeShuffleHandle,
           mapId,
           context,

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index bbd9c1a..808317b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -52,13 +52,13 @@ private[spark] class SortShuffleWriter[K, V, C](
     sorter = if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       new ExternalSorter[K, V, C](
-        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
       // if the operation being run is sortByKey.
       new ExternalSorter[K, V, V](
-        aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
+        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
     }
     sorter.insertAll(records)
 
@@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C](
     // (see SPARK-3570).
     val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
+    val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
     shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
 
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)



[2/3] spark git commit: [SPARK-10984] Simplify *MemoryManager class structure

Posted by jo...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index cfa58f5..f6d81ee 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -28,8 +28,10 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 import org.apache.spark.executor.ShuffleWriteMetrics
 
@@ -55,12 +57,30 @@ class ExternalAppendOnlyMap[K, V, C](
     mergeValue: (C, V) => C,
     mergeCombiners: (C, C) => C,
     serializer: Serializer = SparkEnv.get.serializer,
-    blockManager: BlockManager = SparkEnv.get.blockManager)
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    context: TaskContext = TaskContext.get())
   extends Iterable[(K, C)]
   with Serializable
   with Logging
   with Spillable[SizeTracker] {
 
+  if (context == null) {
+    throw new IllegalStateException(
+      "Spillable collections should not be instantiated outside of tasks")
+  }
+
+  // Backwards-compatibility constructor for binary compatibility
+  def this(
+      createCombiner: V => C,
+      mergeValue: (C, V) => C,
+      mergeCombiners: (C, C) => C,
+      serializer: Serializer,
+      blockManager: BlockManager) {
+    this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get())
+  }
+
+  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+
   private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
   private val sparkConf = SparkEnv.get.conf
@@ -118,6 +138,10 @@ class ExternalAppendOnlyMap[K, V, C](
    * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
    */
   def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
+    if (currentMap == null) {
+      throw new IllegalStateException(
+        "Cannot insert new elements into a map after calling iterator")
+    }
     // An update function for the map that we reuse across entries to avoid allocating
     // a new closure each time
     var curEntry: Product2[K, V] = null
@@ -215,17 +239,26 @@ class ExternalAppendOnlyMap[K, V, C](
   }
 
   /**
-   * Return an iterator that merges the in-memory map with the spilled maps.
+   * Return a destructive iterator that merges the in-memory map with the spilled maps.
    * If no spill has occurred, simply return the in-memory map's iterator.
    */
   override def iterator: Iterator[(K, C)] = {
+    if (currentMap == null) {
+      throw new IllegalStateException(
+        "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
+    }
     if (spilledMaps.isEmpty) {
-      currentMap.iterator
+      CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
     } else {
       new ExternalIterator()
     }
   }
 
+  private def freeCurrentMap(): Unit = {
+    currentMap = null // So that the memory can be garbage-collected
+    releaseMemory()
+  }
+
   /**
    * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
    */
@@ -237,7 +270,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Input streams are derived both from the in-memory map and spilled maps on disk
     // The in-memory map is sorted in place, while the spilled maps are already in sorted order
-    private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
+    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
+      currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
     private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
 
     inputStreams.foreach { it =>
@@ -493,12 +527,7 @@ class ExternalAppendOnlyMap[K, V, C](
       }
     }
 
-    val context = TaskContext.get()
-    // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in
-    // a TaskContext.
-    if (context != null) {
-      context.addTaskCompletionListener(context => cleanup())
-    }
+    context.addTaskCompletionListener(context => cleanup())
   }
 
   /** Convenience function to hash the given (K, C) pair by the key. */
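
The guards added above make ExternalAppendOnlyMap's iterator explicitly one-shot: the
backing map reference is nulled so it can be garbage-collected, and later calls fail
fast. A generic model of that contract, freeing eagerly for brevity whereas the real
code frees via CompletionIterator once iteration completes:

  class DestructiveOnce[T](private var data: List[T]) {
    def iterator: Iterator[T] = {
      if (data == null) {
        throw new IllegalStateException("iterator is destructive and may only be called once")
      }
      val it = data.iterator
      data = null  // drop the reference so the collection can be garbage-collected
      it
    }
  }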

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c48c453..a44e72b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -27,6 +27,7 @@ import com.google.common.annotations.VisibleForTesting
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
@@ -87,6 +88,7 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
  * - Users are expected to call stop() at the end to delete all the intermediate files.
  */
 private[spark] class ExternalSorter[K, V, C](
+    context: TaskContext,
     aggregator: Option[Aggregator[K, V, C]] = None,
     partitioner: Option[Partitioner] = None,
     ordering: Option[Ordering[K]] = None,
@@ -94,6 +96,8 @@ private[spark] class ExternalSorter[K, V, C](
   extends Logging
   with Spillable[WritablePartitionedPairCollection[K, C]] {
 
+  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+
   private val conf = SparkEnv.get.conf
 
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
@@ -640,7 +644,6 @@ private[spark] class ExternalSorter[K, V, C](
    */
   def writePartitionedFile(
       blockId: BlockId,
-      context: TaskContext,
       outputFile: File): Array[Long] = {
 
     // Track location of each range in the output file
@@ -686,8 +689,11 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   def stop(): Unit = {
+    map = null // So that the memory can be garbage-collected
+    buffer = null // So that the memory can be garbage-collected
     spills.foreach(s => s.file.delete())
     spills.clear()
+    releaseMemory()
   }
 
   /**
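
Since stop() now also releases execution memory, every consumer must pair iteration with
stop(), as the shuffle reader does above with CompletionIterator. A generic sketch of
the same pairing with try/finally; the helper name is ours:

  def drainAndStop[T](iterator: Iterator[T], stop: () => Unit)(f: T => Unit): Unit = {
    try {
      iterator.foreach(f)
    } finally {
      stop()  // delete spill files and release execution memory exactly once
    }
  }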

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index d2a68ca..a76891a 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.util.collection
 
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.{Logging, SparkEnv}
 
 /**
  * Spills contents of an in-memory collection to disk when the memory threshold
@@ -40,7 +40,7 @@ private[spark] trait Spillable[C] extends Logging {
   protected def addElementsRead(): Unit = { _elementsRead += 1 }
 
   // Memory manager that can be used to acquire/release memory
-  private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+  protected[this] def taskMemoryManager: TaskMemoryManager
 
   // Initial threshold for the size of a collection before we start tracking its memory usage
   // For testing only
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -92,7 +92,7 @@ private[spark] trait Spillable[C] extends Logging {
       spill(collection)
       _elementsRead = 0
       _memoryBytesSpilled += currentMemory
-      releaseMemoryForThisThread()
+      releaseMemory()
     }
     shouldSpill
   }
@@ -103,11 +103,11 @@ private[spark] trait Spillable[C] extends Logging {
   def memoryBytesSpilled: Long = _memoryBytesSpilled
 
   /**
-   * Release our memory back to the shuffle pool so that other threads can grab it.
+   * Release our memory back to the execution pool so that other tasks can grab it.
    */
-  private def releaseMemoryForThisThread(): Unit = {
+  def releaseMemory(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
-    shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
+    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
     myMemoryThreshold = initialMemoryThreshold
   }
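
The request size in maybeSpill above implements a doubling strategy: each time the
collection's estimated size crosses the threshold, ask for enough memory to double the
footprint. A worked model, with our own names:

  def amountToRequest(currentMemory: Long, myMemoryThreshold: Long): Long =
    2 * currentMemory - myMemoryThreshold

  // Example: threshold 5MB, collection now estimated at 6MB. We request
  // 2*6 - 5 = 7MB more; fully granted, the threshold becomes 12MB, i.e.
  // double the current size. A grant near 0 triggers a spill instead.
  amountToRequest(6L << 20, 5L << 20)  // == 7MB in bytes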
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
new file mode 100644
index 0000000..f381db0
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+public class TaskMemoryManagerSuite {
+
+  @Test
+  public void leakedPageMemoryIsDetected() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    manager.allocatePage(4096);  // leak memory
+    Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOffHeap() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+    // encode. This test exercises that corner case:
+    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+    Assert.assertEquals(null, manager.getPage(encodedAddress));
+    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOnHeap() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+  }
+
+}

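The encoding exercised above packs a 13-bit page number and a 51-bit in-page offset into a
single long. A minimal sketch of that packing (the bit widths mirror TaskMemoryManager's
PAGE_NUMBER_BITS/OFFSET_BITS constants; the helper object itself is illustrative):

    object PageAddressEncodingSketch {
      val PageNumberBits = 13
      val OffsetBits = 64 - PageNumberBits        // 51, as in TaskMemoryManager.OFFSET_BITS
      val OffsetMask = (1L << OffsetBits) - 1

      def encode(pageNumber: Int, offsetInPage: Long): Long =
        (pageNumber.toLong << OffsetBits) | (offsetInPage & OffsetMask)

      def decodePageNumber(encoded: Long): Int = (encoded >>> OffsetBits).toInt
      def decodeOffset(encoded: Long): Long = encoded & OffsetMask

      def main(args: Array[String]): Unit = {
        val addr = encode(pageNumber = 3, offsetInPage = 64)
        assert(decodePageNumber(addr) == 3 && decodeOffset(addr) == 64)
        // An absolute off-heap address can exceed 51 bits, so naive packing truncates
        // it; the manager avoids this by storing off-heap offsets relative to the
        // page base, which is the round-trip the off-heap test above checks.
        val absolute = (1L << OffsetBits) + 10
        assert(decodeOffset(encode(0, absolute)) == 10)  // high bits lost by the mask
      }
    }
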
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 232ae4d..7fb2f92 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -21,18 +21,19 @@ import org.apache.spark.shuffle.sort.PackedRecordPointer;
 import org.junit.Test;
 import static org.junit.Assert.*;
 
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
 
 public class PackedRecordPointerSuite {
 
   @Test
   public void heap() {
+    final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
     final MemoryBlock page0 = memoryManager.allocatePage(128);
     final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
@@ -49,8 +50,9 @@ public class PackedRecordPointerSuite {
 
   @Test
   public void offHeap() {
+    final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
     final MemoryBlock page0 = memoryManager.allocatePage(128);
     final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 1ef3c5f..5049a53 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -24,11 +24,11 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 public class ShuffleInMemorySorterSuite {
 
@@ -58,8 +58,9 @@ public class ShuffleInMemorySorterSuite {
       "Lychee",
       "Mango"
     };
+    final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
     final MemoryBlock dataPage = memoryManager.allocatePage(2048);
     final Object baseObject = dataPage.getBaseObject();
     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 29d9823..d659269 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -39,7 +39,6 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.lessThan;
 import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
@@ -54,19 +53,15 @@ import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
 import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;
-  final TaskMemoryManager taskMemoryManager =
-    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  TaskMemoryManager taskMemoryManager;
   final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
   File mergedOutputFile;
   File tempDir;
@@ -76,7 +71,6 @@ public class UnsafeShuffleWriterSuite {
   final Serializer serializer = new KryoSerializer(new SparkConf());
   TaskMetrics taskMetrics;
 
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@@ -111,11 +105,11 @@ public class UnsafeShuffleWriterSuite {
     mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
-    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+    conf = new SparkConf()
+      .set("spark.buffer.pageSize", "128m")
+      .set("spark.unsafe.offHeap", "false");
     taskMetrics = new TaskMetrics();
-
-    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+    taskMemoryManager = new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
 
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
@@ -203,7 +197,6 @@ public class UnsafeShuffleWriterSuite {
       blockManager,
       shuffleBlockResolver,
       taskMemoryManager,
-      shuffleMemoryManager,
       new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
       0, // map id
       taskContext,
@@ -405,11 +398,12 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeEnoughDataToTriggerSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to allocate new data page
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    taskMemoryManager = spy(taskMemoryManager);
+    doCallRealMethod() // initialize sort buffer
+      .doCallRealMethod() // allocate initial data page
+      .doReturn(0L) // deny request to allocate new page
+      .doCallRealMethod() // grant new sort buffer and data page
+      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
     final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
@@ -417,7 +411,7 @@ public class UnsafeShuffleWriterSuite {
       dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
     }
     writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -432,18 +426,19 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to grow sort buffer
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    taskMemoryManager = spy(taskMemoryManager);
+    doCallRealMethod() // initialize sort buffer
+      .doCallRealMethod() // allocate initial data page
+      .doReturn(0L) // deny request to grow sort buffer
+      .doCallRealMethod() // grant new sort buffer and data page
+      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -509,13 +504,13 @@ public class UnsafeShuffleWriterSuite {
     final long recordLengthBytes = 8;
     final long pageSizeBytes = 256;
     final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+    taskMemoryManager = spy(taskMemoryManager);
+    when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
     final UnsafeShuffleWriter<Object, Object> writer =
       new UnsafeShuffleWriter<Object, Object>(
         blockManager,
         shuffleBlockResolver,
         taskMemoryManager,
-        shuffleMemoryManager,
         new SerializedShuffleHandle<>(0, 1, shuffleDep),
         0, // map id
         taskContext,

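The spy-based stubbing above encodes a fixed grant/deny schedule: per the comments in the
tests, call 1 grants the initial sort buffer, call 2 the initial data page, call 3 is denied
to force a spill, and calls 4 and 5 grant the post-spill sort buffer and data page, which is
why times(5) is verified. The same schedule as plain logic (illustrative, not a Spark API):

    object GrantScheduleSketch {
      private var calls = 0

      // Grant every request except the third, mirroring the doReturn(0L) stub.
      def acquireExecutionMemory(requested: Long): Long = {
        calls += 1
        if (calls == 3) 0L else requested
      }

      def main(args: Array[String]): Unit = {
        val grants = (1 to 5).map(_ => acquireExecutionMemory(4096L))
        println(grants.mkString(", "))  // 4096, 4096, 0, 4096, 4096
      }
    }
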
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index ab480b6..6e52496 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -21,15 +21,13 @@ import java.lang.Exception;
 import java.nio.ByteBuffer;
 import java.util.*;
 
+import org.apache.spark.memory.TaskMemoryManager;
 import org.junit.*;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.junit.Assert.*;
-import static org.mockito.AdditionalMatchers.geq;
-import static org.mockito.Mockito.*;
 
-import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.*;
 import org.apache.spark.unsafe.Platform;
@@ -39,42 +37,29 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
-  private ShuffleMemoryManager shuffleMemoryManager;
+  private GrantEverythingMemoryManager memoryManager;
   private TaskMemoryManager taskMemoryManager;
-  private TaskMemoryManager sizeLimitedTaskMemoryManager;
   private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
   @Before
   public void setup() {
-    shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES);
-    taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
-    // Mocked memory manager for tests that check the maximum array size, since actually allocating
-    // such large arrays will cause us to run out of memory in our tests.
-    sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
-    when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
-      new Answer<MemoryBlock>() {
-        @Override
-        public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
-          if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
-            throw new OutOfMemoryError("Requested array size exceeds VM limit");
-          }
-          return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
-        }
-      }
-    );
+    memoryManager =
+      new GrantEverythingMemoryManager(
+        new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
+    taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   }
 
   @After
   public void tearDown() {
     Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
-    if (shuffleMemoryManager != null) {
-      long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
-      shuffleMemoryManager = null;
-      Assert.assertEquals(0L, leakedShuffleMemory);
+    if (taskMemoryManager != null) {
+      long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
+      taskMemoryManager = null;
+      Assert.assertEquals(0L, leakedMemory);
     }
   }
 
-  protected abstract MemoryAllocator getMemoryAllocator();
+  protected abstract boolean useOffHeapMemoryAllocator();
 
   private static byte[] getByteArray(MemoryLocation loc, int size) {
     final byte[] arr = new byte[size];
@@ -110,8 +95,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void emptyMap() {
-    BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
     try {
       Assert.assertEquals(0, map.numElements());
       final int keyLengthInWords = 10;
@@ -126,8 +110,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void setAndRetrieveAKey() {
-    BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
     final int recordLengthWords = 10;
     final int recordLengthBytes = recordLengthWords * 8;
     final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -179,8 +162,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   private void iteratorTestBase(boolean destructive) throws Exception {
     final int size = 4096;
-    BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES);
     try {
       for (long i = 0; i < size; i++) {
         final long[] value = new long[] { i };
@@ -265,8 +247,8 @@ public abstract class AbstractBytesToBytesMapSuite {
     final int NUM_ENTRIES = 1000 * 1000;
     final int KEY_LENGTH = 24;
     final int VALUE_LENGTH = 40;
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+    final BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
     // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte
     // pages won't be evenly-divisible by records of this size, which will cause us to waste some
     // space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -335,9 +317,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
     // into ByteBuffers in order to use them as keys here.
     final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
-
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES);
     try {
       // Fill the map to 90% full so that we can trigger probing
       for (int i = 0; i < size * 0.9; i++) {
@@ -386,8 +366,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void randomizedTestWithRecordsLargerThanPageSize() {
     final long pageSizeBytes = 128;
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes);
     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
     // into ByteBuffers in order to use them as keys here.
     final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
@@ -436,9 +415,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void failureToAllocateFirstPage() {
-    shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024);
-    BytesToBytesMap map =
-      new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    memoryManager.markExecutionAsOutOfMemory();
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
+    memoryManager.markExecutionAsOutOfMemory();
     try {
       final long[] emptyArray = new long[0];
       final BytesToBytesMap.Location loc =
@@ -454,12 +433,14 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void failureToGrow() {
-    shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10);
-    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024);
     try {
       boolean success = true;
       int i;
-      for (i = 0; i < 1024; i++) {
+      for (i = 0; i < 127; i++) {
+        if (i > 0) {
+          memoryManager.markExecutionAsOutOfMemory();
+        }
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
         success =
@@ -478,7 +459,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void initialCapacityBoundsChecking() {
     try {
-      new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
+      new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
@@ -486,36 +467,13 @@ public abstract class AbstractBytesToBytesMapSuite {
 
     try {
       new BytesToBytesMap(
-        sizeLimitedTaskMemoryManager,
-        shuffleMemoryManager,
+        taskMemoryManager,
         BytesToBytesMap.MAX_CAPACITY + 1,
         PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
     }
-
-    // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
-    // Can allocate _at_ the max capacity
-    //    BytesToBytesMap map = new BytesToBytesMap(
-    //      sizeLimitedTaskMemoryManager,
-    //      shuffleMemoryManager,
-    //      BytesToBytesMap.MAX_CAPACITY,
-    //      PAGE_SIZE_BYTES);
-    //    map.free();
-  }
-
-  // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
-  @Ignore
-  public void resizingLargeMap() {
-    // As long as a map's capacity is below the max, we should be able to resize up to the max
-    BytesToBytesMap map = new BytesToBytesMap(
-      sizeLimitedTaskMemoryManager,
-      shuffleMemoryManager,
-      BytesToBytesMap.MAX_CAPACITY - 64,
-      PAGE_SIZE_BYTES);
-    map.growAndRehash();
-    map.free();
   }
 
   @Test
@@ -523,8 +481,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     final long recordLengthBytes = 24;
     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
 
     // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
     // monotonically increasing. More specifically, every time we allocate a new page it
@@ -564,8 +521,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void testAcquirePageInConstructor() {
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
     assertEquals(1, map.getNumDataPages());
     map.free();
   }

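The page-waste arithmetic in the million-entry test above, worked out. The numbers follow
the test's own comment (8 + 24 + 40 = 72 bytes per record in 64-megabyte pages):

    object PageWasteSketch {
      def main(args: Array[String]): Unit = {
        val pageSizeBytes = 1L << 26                     // 64 MB, as in PAGE_SIZE_BYTES
        val recordSize = 8 + 24 + 40                     // 72 bytes per record
        val recordsPerPage = pageSizeBytes / recordSize  // 932067 whole records
        val wastedBytes = pageSizeBytes - recordsPerPage * recordSize
        println(s"$recordsPerPage records per page, $wastedBytes bytes wasted")  // 40 wasted
      }
    }
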
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
index 5a10de4..f0bad4d 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -17,13 +17,10 @@
 
 package org.apache.spark.unsafe.map;
 
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-
 public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite {
 
   @Override
-  protected MemoryAllocator getMemoryAllocator() {
-    return MemoryAllocator.UNSAFE;
+  protected boolean useOffHeapMemoryAllocator() {
+    return true;
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
index 12cc9b2..d76bb4f 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -17,13 +17,10 @@
 
 package org.apache.spark.unsafe.map;
 
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-
 public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite {
 
   @Override
-  protected MemoryAllocator getMemoryAllocator() {
-    return MemoryAllocator.HEAP;
+  protected boolean useOffHeapMemoryAllocator() {
+    return false;
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a5bbaa9..94d50b9 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -46,20 +46,19 @@ import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 public class UnsafeExternalSorterSuite {
 
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
-  final TaskMemoryManager taskMemoryManager =
-    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final GrantEverythingMemoryManager memoryManager =
+    new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+  final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = new PrefixComparator() {
     @Override
@@ -82,7 +81,6 @@ public class UnsafeExternalSorterSuite {
 
   SparkConf sparkConf;
   File tempDir;
-  ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@@ -102,7 +100,6 @@ public class UnsafeExternalSorterSuite {
     MockitoAnnotations.initMocks(this);
     sparkConf = new SparkConf();
     tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
-    shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes);
     spillFilesCreated.clear();
     taskContext = mock(TaskContext.class);
     when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
@@ -143,13 +140,7 @@ public class UnsafeExternalSorterSuite {
   @After
   public void tearDown() {
     try {
-      long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
-      if (shuffleMemoryManager != null) {
-        long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
-        shuffleMemoryManager = null;
-        assertEquals(0L, leakedShuffleMemory);
-      }
-      assertEquals(0, leakedUnsafeMemory);
+      assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
     } finally {
       Utils.deleteRecursively(tempDir);
       tempDir = null;
@@ -178,7 +169,6 @@ public class UnsafeExternalSorterSuite {
   private UnsafeExternalSorter newSorter() throws IOException {
     return UnsafeExternalSorter.create(
       taskMemoryManager,
-      shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,
@@ -236,12 +226,16 @@ public class UnsafeExternalSorterSuite {
 
   @Test
   public void spillingOccursInResponseToMemoryPressure() throws Exception {
-    shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes);
     final UnsafeExternalSorter sorter = newSorter();
-    final int numRecords = (int) pageSizeBytes / 4;
-    for (int i = 0; i <= numRecords; i++) {
+    // This should be enough records to completely fill up a data page:
+    final int numRecords = (int) (pageSizeBytes / (4 + 4));
+    for (int i = 0; i < numRecords; i++) {
       insertNumber(sorter, numRecords - i);
     }
+    assertEquals(1, sorter.getNumberOfAllocatedPages());
+    memoryManager.markExecutionAsOutOfMemory();
+    // The insertion of this record should trigger a spill:
+    insertNumber(sorter, 0);
     // Ensure that spill files were created
     assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1));
     // Read back the sorted data:
@@ -255,6 +249,7 @@ public class UnsafeExternalSorterSuite {
       assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
       i++;
     }
+    assertEquals(numRecords + 1, i);
     sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
@@ -323,7 +318,6 @@ public class UnsafeExternalSorterSuite {
     final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
     final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
       taskMemoryManager,
-      shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,

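The fill-then-spill arithmetic in spillingOccursInResponseToMemoryPressure above, assuming
each inserted int occupies 8 bytes in the data page (reading the test's 4 + 4 as a 4-byte
record-length prefix plus the 4-byte value; the page size below is illustrative):

    object SpillTriggerSketch {
      def main(args: Array[String]): Unit = {
        val pageSizeBytes = 1L << 20            // illustrative; the suite computes its own
        val numRecords = pageSizeBytes / (4 + 4)
        println(s"page holds $numRecords records; record ${numRecords + 1} needs a new page")
        // With the memory manager marked out of memory, that extra page request is
        // denied, so the sorter spills instead.
      }
    }
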
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 778e813..d5de56a 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -26,11 +26,11 @@ import static org.junit.Assert.*;
 import static org.mockito.Mockito.mock;
 
 import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 public class UnsafeInMemorySorterSuite {
 
@@ -43,7 +43,8 @@ public class UnsafeInMemorySorterSuite {
   @Test
   public void testSortingEmptyInput() {
     final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+      new TaskMemoryManager(
+        new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
       mock(RecordComparator.class),
       mock(PrefixComparator.class),
       100);
@@ -64,8 +65,8 @@ public class UnsafeInMemorySorterSuite {
       "Lychee",
       "Mango"
     };
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final TaskMemoryManager memoryManager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
     final MemoryBlock dataPage = memoryManager.allocatePage(2048);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/FailureSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index f58756e..0242cbc 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // cause is preserved
     val thrownDueToTaskFailure = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocate(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128)
         throw new Exception("intentional task failure")
         iter
       }.count()
@@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // If the task succeeded but memory was leaked, then the task should fail due to that leak
     val thrownDueToMemoryLeak = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocate(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128)
         iter
       }.count()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
new file mode 100644
index 0000000..fe102d8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.{BlockStatus, BlockId}
+
+class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
+  private[memory] override def doAcquireExecutionMemory(
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
+    if (oom) {
+      oom = false
+      0
+    } else {
+      _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
+      numBytes
+    }
+  }
+  override def acquireStorageMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def acquireUnrollMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def releaseStorageMemory(numBytes: Long): Unit = { }
+  override def maxExecutionMemory: Long = Long.MaxValue
+  override def maxStorageMemory: Long = Long.MaxValue
+
+  private var oom = false
+
+  def markExecutionAsOutOfMemory(): Unit = {
+    oom = true
+  }
+}

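A minimal usage sketch for the class above: the OOM flag denies exactly one request and then
resets itself. The demo object is hypothetical and sits in org.apache.spark.memory because
doAcquireExecutionMemory is private[memory]:

    package org.apache.spark.memory

    import scala.collection.mutable

    import org.apache.spark.SparkConf
    import org.apache.spark.storage.{BlockId, BlockStatus}

    object GrantEverythingDemo {
      def main(args: Array[String]): Unit = {
        val mm = new GrantEverythingMemoryManager(
          new SparkConf().set("spark.unsafe.offHeap", "false"))
        val evicted = mutable.Buffer.empty[(BlockId, BlockStatus)]
        mm.markExecutionAsOutOfMemory()
        println(mm.doAcquireExecutionMemory(100L, evicted))  // 0: denied once
        println(mm.doAcquireExecutionMemory(100L, evicted))  // 100: granted again
      }
    }
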
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 36e4566..1265087 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -19,10 +19,14 @@ package org.apache.spark.memory
 
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, Future}
+
 import org.mockito.Matchers.{any, anyLong}
 import org.mockito.Mockito.{mock, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.storage.MemoryStore
@@ -126,6 +130,136 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED,
       "ensure free space should not have been called!")
   }
+
+  /**
+   * Create a MemoryManager with the specified execution memory limit and no storage memory.
+   */
+  protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager
+
+  // -- Tests of sharing of execution memory between tasks ----------------------------------------
+  // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite.
+
+  implicit val ec = ExecutionContext.global
+
+  test("single task requesting execution memory") {
+    val manager = createMemoryManager(1000L)
+    val taskMemoryManager = new TaskMemoryManager(manager, 0)
+
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+
+    taskMemoryManager.releaseExecutionMemory(500L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
+
+    taskMemoryManager.cleanUpAllAllocatedMemory()
+    assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+  }
+
+  test("two tasks requesting full execution memory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // Have both tasks request 500 bytes, then wait until both requests have been granted:
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t1Result1, futureTimeout) === 500L)
+    assert(Await.result(t2Result1, futureTimeout) === 500L)
+
+    // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
+    // both now at 1 / N
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t1Result2, 200.millis) === 0L)
+    assert(Await.result(t2Result2, 200.millis) === 0L)
+  }
+
+  test("two tasks cannot grow past 1 / N of execution memory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // Have both tasks request 250 bytes, then wait until both requests have been granted:
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    assert(Await.result(t1Result1, futureTimeout) === 250L)
+    assert(Await.result(t2Result1, futureTimeout) === 250L)
+
+    // Have both tasks each request 500 bytes more.
+    // We should only grant 250 bytes to each of them on this second request
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t1Result2, futureTimeout) === 250L)
+    assert(Await.result(t2Result2, futureTimeout) === 250L)
+  }
+
+  test("tasks can block to get at least 1 / 2N of execution memory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    assert(Await.result(t1Result1, futureTimeout) === 1000L)
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
+    // to make sure the other thread blocks for some time otherwise.
+    Thread.sleep(300)
+    t1MemManager.releaseExecutionMemory(250L)
+    // The memory freed from t1 should now be granted to t2.
+    assert(Await.result(t2Result1, futureTimeout) === 250L)
+    // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
+    assert(Await.result(t2Result2, 200.millis) === 0L)
+  }
+
+  test("TaskMemoryManager.cleanUpAllAllocatedMemory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    assert(Await.result(t1Result1, futureTimeout) === 1000L)
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
+    // to make sure the other thread blocks for some time otherwise.
+    Thread.sleep(300)
+    // t1 releases all of its memory, so t2 should be able to grab all of the memory
+    t1MemManager.cleanUpAllAllocatedMemory()
+    assert(Await.result(t2Result1, futureTimeout) === 500L)
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t2Result2, futureTimeout) === 500L)
+    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t2Result3, 200.millis) === 0L)
+  }
+
+  test("tasks should not be granted a negative amount of execution memory") {
+    // This is a regression test for SPARK-4715.
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
+    assert(Await.result(t1Result1, futureTimeout) === 700L)
+
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
+    assert(Await.result(t2Result1, futureTimeout) === 300L)
+
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
+    assert(Await.result(t1Result2, 200.millis) === 0L)
+  }
 }
 
 private object MemoryManagerSuite {

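The fairness bounds those tests pin down, as plain arithmetic (the helper object is
illustrative): with N active tasks, each task is capped at 1/N of execution memory and may
block until it holds at least 1/2N.

    object FairShareSketch {
      // Returns (guaranteed minimum, cap) for one of numActiveTasks tasks.
      def bounds(maxExecutionMemory: Long, numActiveTasks: Int): (Long, Long) =
        (maxExecutionMemory / (2L * numActiveTasks), maxExecutionMemory / numActiveTasks)

      def main(args: Array[String]): Unit = {
        println(bounds(1000L, 2))  // (250,500): the numbers the two-task tests use
      }
    }
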
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
new file mode 100644
index 0000000..4b4c3b0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+
+/**
+ * Helper methods for mocking out memory-management-related classes in tests.
+ */
+object MemoryTestingUtils {
+  def fakeTaskContext(env: SparkEnv): TaskContext = {
+    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
+    new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 0,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemoryManager,
+      metricsSystem = env.metricsSystem,
+      internalAccumulators = Seq.empty)
+  }
+}

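A usage sketch for the helper above. The demo object is hypothetical and assumes a running
local SparkContext so that SparkEnv.get is populated:

    import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
    import org.apache.spark.memory.MemoryTestingUtils

    object FakeTaskContextDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("demo"))
        try {
          val context = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
          // The fake context carries a real TaskMemoryManager, so memory-hungry
          // components (e.g. an ExternalSorter) can be built outside a real task.
          println(context.taskMemoryManager() != null)  // true
        } finally {
          sc.stop()
        }
      }
    }
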
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
index 6cae1f8..885c450 100644
--- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
@@ -36,27 +36,35 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite {
       maxExecutionMem: Long,
       maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = {
     val mm = new StaticMemoryManager(
-      conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem)
+      conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1)
     val ms = makeMemoryStore(mm)
     (mm, ms)
   }
 
+  override protected def createMemoryManager(maxMemory: Long): MemoryManager = {
+    new StaticMemoryManager(
+      conf,
+      maxExecutionMemory = maxMemory,
+      maxStorageMemory = 0,
+      numCores = 1)
+  }
+
   test("basic execution memory") {
     val maxExecutionMem = 1000L
     val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue)
     assert(mm.executionMemoryUsed === 0L)
-    assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L)
+    assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L)
     assert(mm.executionMemoryUsed === 10L)
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     // Acquire up to the max
-    assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L)
+    assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L)
     assert(mm.executionMemoryUsed === maxExecutionMem)
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L)
     assert(mm.executionMemoryUsed === maxExecutionMem)
     mm.releaseExecutionMemory(800L)
     assert(mm.executionMemoryUsed === 200L)
     // Acquire after release
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L)
     assert(mm.executionMemoryUsed === 201L)
     // Release beyond what was acquired
     mm.releaseExecutionMemory(maxExecutionMem)
@@ -108,10 +116,10 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite {
     val dummyBlock = TestBlockId("ain't nobody love like you do")
     val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem)
     // Only execution memory should increase
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     assert(mm.storageMemoryUsed === 0L)
     assert(mm.executionMemoryUsed === 100L)
-    assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L)
     assert(mm.storageMemoryUsed === 0L)
     assert(mm.executionMemoryUsed === 200L)
     // Only storage memory should increase

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
index e7baa50..0c97f2b 100644
--- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
@@ -34,11 +34,15 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
    * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies.
    */
   private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = {
-    val mm = new UnifiedMemoryManager(conf, maxMemory)
+    val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1)
     val ms = makeMemoryStore(mm)
     (mm, ms)
   }
 
+  override protected def createMemoryManager(maxMemory: Long): MemoryManager = {
+    new UnifiedMemoryManager(conf, maxMemory, numCores = 1)
+  }
+
   private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = {
     mm invokePrivate PrivateMethod[Long]('storageRegionSize)()
   }
@@ -56,18 +60,18 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     val maxMemory = 1000L
     val (mm, _) = makeThings(maxMemory)
     assert(mm.executionMemoryUsed === 0L)
-    assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L)
+    assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L)
     assert(mm.executionMemoryUsed === 10L)
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     // Acquire up to the max
-    assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L)
+    assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L)
     assert(mm.executionMemoryUsed === maxMemory)
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L)
     assert(mm.executionMemoryUsed === maxMemory)
     mm.releaseExecutionMemory(800L)
     assert(mm.executionMemoryUsed === 200L)
     // Acquire after release
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L)
     assert(mm.executionMemoryUsed === 201L)
     // Release beyond what was acquired
     mm.releaseExecutionMemory(maxMemory)
@@ -132,12 +136,12 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     require(mm.storageMemoryUsed > storageRegionSize,
       s"bad test: storage memory used should exceed the storage region")
     // Execution needs to request 250 bytes to evict storage memory
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     assert(mm.executionMemoryUsed === 100L)
     assert(mm.storageMemoryUsed === 750L)
     assertEnsureFreeSpaceNotCalled(ms)
     // Execution wants 200 bytes but only 150 are free, so storage is evicted
-    assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L)
+    assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L)
     assertEnsureFreeSpaceCalled(ms, 200L)
     assert(mm.executionMemoryUsed === 300L)
     mm.releaseAllStorageMemory()
@@ -151,7 +155,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
       s"bad test: storage memory used should be within the storage region")
     // Execution cannot evict storage because the latter is within the storage fraction,
     // so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300
-    assert(mm.acquireExecutionMemory(400L, evictedBlocks) === 300L)
+    assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L)
     assert(mm.executionMemoryUsed === 600L)
     assert(mm.storageMemoryUsed === 400L)
     assertEnsureFreeSpaceNotCalled(ms)
@@ -170,7 +174,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     require(executionRegionSize === expectedExecutionRegionSize,
       "bad test: storage region size is unexpected")
     // Acquire enough execution memory to exceed the execution region
-    assert(mm.acquireExecutionMemory(800L, evictedBlocks) === 800L)
+    assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L)
     assert(mm.executionMemoryUsed === 800L)
     assert(mm.storageMemoryUsed === 0L)
     assertEnsureFreeSpaceNotCalled(ms)
@@ -188,7 +192,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     mm.releaseExecutionMemory(maxMemory)
     mm.releaseStorageMemory(maxMemory)
     // Acquire some execution memory again, but this time keep it within the execution region
-    assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L)
+    assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L)
     assert(mm.executionMemoryUsed === 200L)
     assert(mm.storageMemoryUsed === 0L)
     assertEnsureFreeSpaceNotCalled(ms)

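A simplified model of the eviction policy the comments above describe (the helper and the
500-byte region size are illustrative; the request and usage numbers mirror the tests):
execution may evict storage only down to the protected storage region, and otherwise
receives just the free remainder.

    object UnifiedEvictionSketch {
      def grantToExecution(
          request: Long,
          maxMemory: Long,
          executionUsed: Long,
          storageUsed: Long,
          storageRegionSize: Long): Long = {
        val free = maxMemory - executionUsed - storageUsed
        // Storage held above the protected region is evictable on demand.
        val evictable = math.max(0L, storageUsed - storageRegionSize)
        math.min(request, free + evictable)
      }

      def main(args: Array[String]): Unit = {
        // "wants 200 bytes but only 150 are free, so storage is evicted":
        println(grantToExecution(200L, 1000L, 100L, 750L, 500L))  // 200
        // "grant only what's remaining without evicting: 1000 - 300 - 400 = 300":
        println(grantToExecution(400L, 1000L, 300L, 400L, 500L))  // 300
      }
    }
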
http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
deleted file mode 100644
index 5877aa0..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
+++ /dev/null
@@ -1,326 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle
-
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.mockito.Mockito._
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.time.SpanSugar._
-
-import org.apache.spark.{SparkFunSuite, TaskContext}
-import org.apache.spark.executor.TaskMetrics
-
-class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
-
-  val nextTaskAttemptId = new AtomicInteger()
-
-  /** Launch a thread with the given body block and return it. */
-  private def startThread(name: String)(body: => Unit): Thread = {
-    val thread = new Thread("ShuffleMemorySuite " + name) {
-      override def run() {
-        try {
-          val taskAttemptId = nextTaskAttemptId.getAndIncrement
-          val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS)
-          val taskMetrics = new TaskMetrics
-          when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
-          when(mockTaskContext.taskMetrics()).thenReturn(taskMetrics)
-          TaskContext.setTaskContext(mockTaskContext)
-          body
-        } finally {
-          TaskContext.unset()
-        }
-      }
-    }
-    thread.start()
-    thread
-  }
-
-  test("single task requesting memory") {
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    assert(manager.tryToAcquire(100L) === 100L)
-    assert(manager.tryToAcquire(400L) === 400L)
-    assert(manager.tryToAcquire(400L) === 400L)
-    assert(manager.tryToAcquire(200L) === 100L)
-    assert(manager.tryToAcquire(100L) === 0L)
-    assert(manager.tryToAcquire(100L) === 0L)
-
-    manager.release(500L)
-    assert(manager.tryToAcquire(300L) === 300L)
-    assert(manager.tryToAcquire(300L) === 200L)
-
-    manager.releaseMemoryForThisTask()
-    assert(manager.tryToAcquire(1000L) === 1000L)
-    assert(manager.tryToAcquire(100L) === 0L)
-  }
-
-  test("two threads requesting full memory") {
-    // Two threads request 500 bytes first, wait for each other to get it, and then request
-    // 500 more; we should immediately return 0 as both are now at 1 / N
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Result1 = -1L
-      var t2Result1 = -1L
-      var t1Result2 = -1L
-      var t2Result2 = -1L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      val r1 = manager.tryToAcquire(500L)
-      state.synchronized {
-        state.t1Result1 = r1
-        state.notifyAll()
-        while (state.t2Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t1Result2 = r2 }
-    }
-
-    val t2 = startThread("t2") {
-      val r1 = manager.tryToAcquire(500L)
-      state.synchronized {
-        state.t2Result1 = r1
-        state.notifyAll()
-        while (state.t1Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t2Result2 = r2 }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    assert(state.t1Result1 === 500L)
-    assert(state.t2Result1 === 500L)
-    assert(state.t1Result2 === 0L)
-    assert(state.t2Result2 === 0L)
-  }
-
-
-  test("tasks cannot grow past 1 / N") {
-    // Two tasks request 250 bytes first, wait for each other to get it, and then request
-    // 500 more; we should only grant 250 bytes to each of them on this second request
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Result1 = -1L
-      var t2Result1 = -1L
-      var t1Result2 = -1L
-      var t2Result2 = -1L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      val r1 = manager.tryToAcquire(250L)
-      state.synchronized {
-        state.t1Result1 = r1
-        state.notifyAll()
-        while (state.t2Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t1Result2 = r2 }
-    }
-
-    val t2 = startThread("t2") {
-      val r1 = manager.tryToAcquire(250L)
-      state.synchronized {
-        state.t2Result1 = r1
-        state.notifyAll()
-        while (state.t1Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t2Result2 = r2 }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    assert(state.t1Result1 === 250L)
-    assert(state.t2Result1 === 250L)
-    assert(state.t1Result2 === 250L)
-    assert(state.t2Result2 === 250L)
-  }
-
-  test("tasks can block to get at least 1 / 2N memory") {
-    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
-    // for a bit and releases 250 bytes, which should then be granted to t2. Further requests
-    // by t2 will return false right away because it now has 1 / 2N of the memory.
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Requested = false
-      var t2Requested = false
-      var t1Result = -1L
-      var t2Result = -1L
-      var t2Result2 = -1L
-      var t2WaitTime = 0L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      state.synchronized {
-        state.t1Result = manager.tryToAcquire(1000L)
-        state.t1Requested = true
-        state.notifyAll()
-        while (!state.t2Requested) {
-          state.wait()
-        }
-      }
-      // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
-      // sure the other thread blocks for some time otherwise
-      Thread.sleep(300)
-      manager.release(250L)
-    }
-
-    val t2 = startThread("t2") {
-      state.synchronized {
-        while (!state.t1Requested) {
-          state.wait()
-        }
-        state.t2Requested = true
-        state.notifyAll()
-      }
-      val startTime = System.currentTimeMillis()
-      val result = manager.tryToAcquire(250L)
-      val endTime = System.currentTimeMillis()
-      state.synchronized {
-        state.t2Result = result
-        // A second call should return 0 because we're now already at 1 / 2N
-        state.t2Result2 = manager.tryToAcquire(100L)
-        state.t2WaitTime = endTime - startTime
-      }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    // Both threads should've been able to acquire their memory; the second one will have waited
-    // until the first one acquired 1000 bytes and then released 250
-    state.synchronized {
-      assert(state.t1Result === 1000L, "t1 could not allocate memory")
-      assert(state.t2Result === 250L, "t2 could not allocate memory")
-      assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
-      assert(state.t2Result2 === 0L, "t1 got extra memory the second time")
-    }
-  }
-
-  test("releaseMemoryForThisTask") {
-    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
-    // for a bit and releases all its memory. t2 should now be able to grab all the memory.
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Requested = false
-      var t2Requested = false
-      var t1Result = -1L
-      var t2Result1 = -1L
-      var t2Result2 = -1L
-      var t2Result3 = -1L
-      var t2WaitTime = 0L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      state.synchronized {
-        state.t1Result = manager.tryToAcquire(1000L)
-        state.t1Requested = true
-        state.notifyAll()
-        while (!state.t2Requested) {
-          state.wait()
-        }
-      }
-      // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
-      // sure the other task blocks for some time otherwise
-      Thread.sleep(300)
-      manager.releaseMemoryForThisTask()
-    }
-
-    val t2 = startThread("t2") {
-      state.synchronized {
-        while (!state.t1Requested) {
-          state.wait()
-        }
-        state.t2Requested = true
-        state.notifyAll()
-      }
-      val startTime = System.currentTimeMillis()
-      val r1 = manager.tryToAcquire(500L)
-      val endTime = System.currentTimeMillis()
-      val r2 = manager.tryToAcquire(500L)
-      val r3 = manager.tryToAcquire(500L)
-      state.synchronized {
-        state.t2Result1 = r1
-        state.t2Result2 = r2
-        state.t2Result3 = r3
-        state.t2WaitTime = endTime - startTime
-      }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    // Both tasks should've been able to acquire their memory; the second one will have waited
-    // until the first one acquired 1000 bytes and then released all of it
-    state.synchronized {
-      assert(state.t1Result === 1000L, "t1 could not allocate memory")
-      assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time")
-      assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time")
-      assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})")
-      assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
-    }
-  }
-
-  test("tasks should not be granted a negative size") {
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-    manager.tryToAcquire(700L)
-
-    val latch = new CountDownLatch(1)
-    startThread("t1") {
-      manager.tryToAcquire(300L)
-      latch.countDown()
-    }
-    latch.await() // Wait until `t1` calls `tryToAcquire`
-
-    val granted = manager.tryToAcquire(300L)
-    assert(0 === granted, "granted is negative")
-  }
-}
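
The suite deleted here documented the cross-task fairness policy: each of the N active tasks may hold at most 1 / N of the pool, and a request blocks until the task can reach at least 1 / 2N. A self-contained sketch of that policy, assuming a single shared pool (a toy class, not the real ShuffleMemoryManager):

    import scala.collection.mutable

    // Toy fairness policy: cap each task at 1/N, block below the 1/2N floor.
    class ToyFairPool(maxMemory: Long) {
      private val used = mutable.HashMap[Long, Long]() // taskId -> bytes held

      def tryToAcquire(taskId: Long, numBytes: Long): Long = synchronized {
        used.getOrElseUpdate(taskId, 0L) // register the task so N counts it
        var granted = -1L
        while (granted < 0) {
          val n = used.size
          val curr = used(taskId)
          val free = maxMemory - used.values.sum
          val maxGrant = math.max(0L, maxMemory / n - curr) // the 1/N cap
          val grant = math.min(numBytes, math.min(maxGrant, free))
          if (grant < numBytes && curr + grant < maxMemory / (2 * n)) {
            wait() // below the 1/2N floor: block until memory is released
          } else {
            used(taskId) = curr + grant
            granted = grant
          }
        }
        granted
      }

      def release(taskId: Long, numBytes: Long): Unit = synchronized {
        used(taskId) = math.max(0L, used(taskId) - numBytes)
        notifyAll()
      }
    }

With maxMemory = 1000L this reproduces the deleted assertions: a lone task gets the full 1000 bytes, a second task blocks on tryToAcquire(250L) until 250 bytes are released, and a further request then returns 0 because the task already sits at its 1 / 2N floor.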

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index cc44c67..6e3f500 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -61,7 +61,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
     val store = new BlockManager(name, rpcEnv, master, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     memManager.setMemoryStore(store.memoryStore)
@@ -261,7 +261,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
     val failableTransfer = mock(classOf[BlockTransferService]) // this won't actually work
     when(failableTransfer.hostName).thenReturn("some-hostname")
     when(failableTransfer.port).thenReturn(1000)
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1)
     val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
     memManager.setMemoryStore(failableStore.memoryStore)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f3fab33..d49015a 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -68,7 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
     val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     memManager.setMemoryStore(blockManager.memoryStore)
@@ -823,7 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   test("block store put failure") {
     // Use Java serializer so we can create an unserializable error.
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200)
+    val memoryManager = new StaticMemoryManager(
+      conf,
+      maxExecutionMemory = Long.MaxValue,
+      maxStorageMemory = 1200,
+      numCores = 1)
     store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
       new JavaSerializer(conf), conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, 0)
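
Threading numCores through the StaticMemoryManager constructor leaves call sites with three positional numeric arguments, which is why the hunk above switches to named arguments. A minimal illustration of the shape (a toy class, not Spark's):

    // Named arguments keep the two Long limits from being silently swapped.
    class ToyStaticMemoryManager(
        val maxExecutionMemory: Long,
        val maxStorageMemory: Long,
        val numCores: Int)

    val mm = new ToyStaticMemoryManager(
      maxExecutionMemory = Long.MaxValue,
      maxStorageMemory = 1200L,
      numCores = 1)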

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5cb506e..dc3185a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
 import org.apache.spark.io.CompressionCodec
-
+import org.apache.spark.memory.MemoryTestingUtils
 
 class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
   import TestUtils.{assertNotSpilled, assertSpilled}
@@ -32,8 +32,11 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
   private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] =
     buf1 ++= buf2
 
-  private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
-    createCombiner[T], mergeValue[T], mergeCombiners[T])
+  private def createExternalMap[T] = {
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+    new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
+      createCombiner[T], mergeValue[T], mergeCombiners[T], context = context)
+  }
 
   private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = {
     val conf = new SparkConf(loadDefaults)
@@ -49,23 +52,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     conf
   }
 
-  test("simple insert") {
+  test("single insert") {
     val conf = createSparkConf(loadDefaults = false)
     sc = new SparkContext("local", "test", conf)
     val map = createExternalMap[Int]
-
-    // Single insert
     map.insert(1, 10)
-    var it = map.iterator
+    val it = map.iterator
     assert(it.hasNext)
     val kv = it.next()
     assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10))
     assert(!it.hasNext)
+    sc.stop()
+  }
 
-    // Multiple insert
+  test("multiple insert") {
+    val conf = createSparkConf(loadDefaults = false)
+    sc = new SparkContext("local", "test", conf)
+    val map = createExternalMap[Int]
+    map.insert(1, 10)
     map.insert(2, 20)
     map.insert(3, 30)
-    it = map.iterator
+    val it = map.iterator
     assert(it.hasNext)
     assert(it.toSet === Set[(Int, ArrayBuffer[Int])](
       (1, ArrayBuffer[Int](10)),
@@ -144,39 +151,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
 
     val map = createExternalMap[Int]
+    val nullInt = null.asInstanceOf[Int]
     map.insert(1, 5)
     map.insert(2, 6)
     map.insert(3, 7)
-    assert(map.size === 3)
-    assert(map.iterator.toSet === Set[(Int, Seq[Int])](
-      (1, Seq[Int](5)),
-      (2, Seq[Int](6)),
-      (3, Seq[Int](7))
-    ))
-
-    // Null keys
-    val nullInt = null.asInstanceOf[Int]
+    map.insert(4, nullInt)
     map.insert(nullInt, 8)
-    assert(map.size === 4)
-    assert(map.iterator.toSet === Set[(Int, Seq[Int])](
+    map.insert(nullInt, nullInt)
+    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted))
+    assert(result === Set[(Int, Seq[Int])](
       (1, Seq[Int](5)),
       (2, Seq[Int](6)),
       (3, Seq[Int](7)),
-      (nullInt, Seq[Int](8))
+      (4, Seq[Int](nullInt)),
+      (nullInt, Seq[Int](nullInt, 8))
     ))
 
-    // Null values
-    map.insert(4, nullInt)
-    map.insert(nullInt, nullInt)
-    assert(map.size === 5)
-    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
-    assert(result === Set[(Int, Set[Int])](
-      (1, Set[Int](5)),
-      (2, Set[Int](6)),
-      (3, Set[Int](7)),
-      (4, Set[Int](nullInt)),
-      (nullInt, Set[Int](nullInt, 8))
-    ))
     sc.stop()
   }
 
@@ -344,7 +334,9 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     val conf = createSparkConf(loadDefaults = true)
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
-    val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+    val map =
+      new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, context = context)
 
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).
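
The SPARK-2043 scenario above needs many distinct keys that nonetheless share a hash code. A guess at the shape of such a key (the real FixedHashObject is defined elsewhere in the test sources; this sketch is an assumption):

    // Distinct keys whose hash is pinned at construction time, forcing collisions.
    case class FixedHash(v: Int, h: Int) {
      override def hashCode(): Int = h
    }

    // Ten copies each of many keys, all hashing to either 0 or 1:
    val toInsert = for (i <- 1 to 10; j <- 1 to 1000) yield (FixedHash(j, j % 2), 1)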

