Posted to commits@spark.apache.org by jo...@apache.org on 2015/10/26 05:20:04 UTC

[2/3] spark git commit: [SPARK-10984] Simplify *MemoryManager class structure

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index cfa58f5..f6d81ee 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -28,8 +28,10 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 import org.apache.spark.executor.ShuffleWriteMetrics
 
@@ -55,12 +57,30 @@ class ExternalAppendOnlyMap[K, V, C](
     mergeValue: (C, V) => C,
     mergeCombiners: (C, C) => C,
     serializer: Serializer = SparkEnv.get.serializer,
-    blockManager: BlockManager = SparkEnv.get.blockManager)
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    context: TaskContext = TaskContext.get())
   extends Iterable[(K, C)]
   with Serializable
   with Logging
   with Spillable[SizeTracker] {
 
+  if (context == null) {
+    throw new IllegalStateException(
+      "Spillable collections should not be instantiated outside of tasks")
+  }
+
+  // Backwards-compatibility constructor for binary compatibility
+  def this(
+      createCombiner: V => C,
+      mergeValue: (C, V) => C,
+      mergeCombiners: (C, C) => C,
+      serializer: Serializer,
+      blockManager: BlockManager) {
+    this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get())
+  }
+
+  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+
   private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
   private val sparkConf = SparkEnv.get.conf
@@ -118,6 +138,10 @@ class ExternalAppendOnlyMap[K, V, C](
    * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
    */
   def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
+    if (currentMap == null) {
+      throw new IllegalStateException(
+        "Cannot insert new elements into a map after calling iterator")
+    }
     // An update function for the map that we reuse across entries to avoid allocating
     // a new closure each time
     var curEntry: Product2[K, V] = null
@@ -215,17 +239,26 @@ class ExternalAppendOnlyMap[K, V, C](
   }
 
   /**
-   * Return an iterator that merges the in-memory map with the spilled maps.
+   * Return a destructive iterator that merges the in-memory map with the spilled maps.
    * If no spill has occurred, simply return the in-memory map's iterator.
    */
   override def iterator: Iterator[(K, C)] = {
+    if (currentMap == null) {
+      throw new IllegalStateException(
+        "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
+    }
     if (spilledMaps.isEmpty) {
-      currentMap.iterator
+      CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
     } else {
       new ExternalIterator()
     }
   }
 
+  private def freeCurrentMap(): Unit = {
+    currentMap = null // So that the memory can be garbage-collected
+    releaseMemory()
+  }
+
   /**
    * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
    */
@@ -237,7 +270,8 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Input streams are derived both from the in-memory map and spilled maps on disk
     // The in-memory map is sorted in place, while the spilled maps are already in sorted order
-    private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
+    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
+      currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
     private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
 
     inputStreams.foreach { it =>
@@ -493,12 +527,7 @@ class ExternalAppendOnlyMap[K, V, C](
       }
     }
 
-    val context = TaskContext.get()
-    // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in
-    // a TaskContext.
-    if (context != null) {
-      context.addTaskCompletionListener(context => cleanup())
-    }
+    context.addTaskCompletionListener(context => cleanup())
   }
 
   /** Convenience function to hash the given (K, C) pair by the key. */
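
The net effect of this change for callers: ExternalAppendOnlyMap now draws execution memory from the per-task TaskMemoryManager and frees it eagerly through a CompletionIterator, which makes iterator a one-shot, destructive operation. A minimal usage sketch of the new contract, assuming it runs inside a task so that TaskContext.get() is non-null (the combiner functions here are hypothetical):

    val map = new ExternalAppendOnlyMap[String, Int, Int](
      createCombiner = (v: Int) => v,
      mergeValue = (c: Int, v: Int) => c + v,
      mergeCombiners = (c1: Int, c2: Int) => c1 + c2)
    map.insertAll(Iterator(("a", 1), ("a", 2), ("b", 3)))
    val it = map.iterator   // frees the in-memory map once exhausted
    // map.iterator         // a second call throws IllegalStateException
    // map.insertAll(...)   // inserting after iterator() also throws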

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c48c453..a44e72b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -27,6 +27,7 @@ import com.google.common.annotations.VisibleForTesting
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
@@ -87,6 +88,7 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
  * - Users are expected to call stop() at the end to delete all the intermediate files.
  */
 private[spark] class ExternalSorter[K, V, C](
+    context: TaskContext,
     aggregator: Option[Aggregator[K, V, C]] = None,
     partitioner: Option[Partitioner] = None,
     ordering: Option[Ordering[K]] = None,
@@ -94,6 +96,8 @@ private[spark] class ExternalSorter[K, V, C](
   extends Logging
   with Spillable[WritablePartitionedPairCollection[K, C]] {
 
+  override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+
   private val conf = SparkEnv.get.conf
 
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
@@ -640,7 +644,6 @@ private[spark] class ExternalSorter[K, V, C](
    */
   def writePartitionedFile(
       blockId: BlockId,
-      context: TaskContext,
       outputFile: File): Array[Long] = {
 
     // Track location of each range in the output file
@@ -686,8 +689,11 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   def stop(): Unit = {
+    map = null // So that the memory can be garbage-collected
+    buffer = null // So that the memory can be garbage-collected
     spills.foreach(s => s.file.delete())
     spills.clear()
+    releaseMemory()
   }
 
   /**
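
For callers, the TaskContext moves from writePartitionedFile to the constructor, and stop() now also returns execution memory. A sketch of the updated call pattern, with records, blockId, and outputFile as placeholders and the unchanged parameters assumed from the surrounding signature:

    val sorter = new ExternalSorter[String, Int, Int](
      TaskContext.get(),              // context is now a constructor argument
      aggregator = None,
      partitioner = Some(new HashPartitioner(4)),
      ordering = None,
      serializer = None)
    sorter.insertAll(records)
    val lengths = sorter.writePartitionedFile(blockId, outputFile)  // no TaskContext parameter
    sorter.stop()   // drops map/buffer references and calls releaseMemory()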

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index d2a68ca..a76891a 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.util.collection
 
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.{Logging, SparkEnv}
 
 /**
  * Spills contents of an in-memory collection to disk when the memory threshold
@@ -40,7 +40,7 @@ private[spark] trait Spillable[C] extends Logging {
   protected def addElementsRead(): Unit = { _elementsRead += 1 }
 
   // Memory manager that can be used to acquire/release memory
-  private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+  protected[this] def taskMemoryManager: TaskMemoryManager
 
   // Initial threshold for the size of a collection before we start tracking its memory usage
   // For testing only
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
-      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+      val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
       myMemoryThreshold += granted
       // If we were granted too little memory to grow further (either tryToAcquire returned 0,
       // or we already had more memory than myMemoryThreshold), spill the current collection
@@ -92,7 +92,7 @@ private[spark] trait Spillable[C] extends Logging {
       spill(collection)
       _elementsRead = 0
       _memoryBytesSpilled += currentMemory
-      releaseMemoryForThisThread()
+      releaseMemory()
     }
     shouldSpill
   }
@@ -103,11 +103,11 @@ private[spark] trait Spillable[C] extends Logging {
   def memoryBytesSpilled: Long = _memoryBytesSpilled
 
   /**
-   * Release our memory back to the shuffle pool so that other threads can grab it.
+   * Release our memory back to the execution pool so that other tasks can grab it.
    */
-  private def releaseMemoryForThisThread(): Unit = {
+  def releaseMemory(): Unit = {
     // The amount we requested does not include the initial memory tracking threshold
-    shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
+    taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
     myMemoryThreshold = initialMemoryThreshold
   }
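
The doubling heuristic in maybeSpill is unchanged; only the provider of memory moves from the global shuffle pool to the per-task manager. A worked example of the request-size arithmetic, with hypothetical sizes:

    val currentMemory     = 6L << 20   // collection currently estimated at 6 MB
    val myMemoryThreshold = 5L << 20   // execution memory claimed so far: 5 MB
    val amountToRequest   = 2 * currentMemory - myMemoryThreshold   // 7 MB
    // If the full 7 MB is granted, the threshold becomes 12 MB, exactly
    // 2 * currentMemory, so the collection can roughly double before the
    // next request. If less than currentMemory - myMemoryThreshold
    // (1 MB here) is granted, the collection spills.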
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
new file mode 100644
index 0000000..f381db0
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+public class TaskMemoryManagerSuite {
+
+  @Test
+  public void leakedPageMemoryIsDetected() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    manager.allocatePage(4096);  // leak memory
+    Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOffHeap() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+    // encode. This test exercises that corner-case:
+    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+    Assert.assertEquals(null, manager.getPage(encodedAddress));
+    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOnHeap() {
+    final TaskMemoryManager manager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+  }
+
+}
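
The encoding tests above rely on splitting a 64-bit address into a page-number field and an in-page-offset field; TaskMemoryManager.OFFSET_BITS in the off-heap test suggests a 13/51-bit split. In off-heap mode the manager stores offsets relative to the page's base address so that even large absolute addresses fit the field, which is the corner case exercised above. A sketch of the bit-packing alone, with illustrative helper names (not the manager's API):

    object AddressSketch {
      val OffsetBits = 51                    // assumed from OFFSET_BITS above
      val OffsetMask = (1L << OffsetBits) - 1
      def encode(pageNumber: Int, offsetInPage: Long): Long =
        (pageNumber.toLong << OffsetBits) | (offsetInPage & OffsetMask)
      def pageNumber(encoded: Long): Int = (encoded >>> OffsetBits).toInt
      def offsetInPage(encoded: Long): Long = encoded & OffsetMask
    }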

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 232ae4d..7fb2f92 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -21,18 +21,19 @@ import org.apache.spark.shuffle.sort.PackedRecordPointer;
 import org.junit.Test;
 import static org.junit.Assert.*;
 
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
 
 public class PackedRecordPointerSuite {
 
   @Test
   public void heap() {
+    final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
     final MemoryBlock page0 = memoryManager.allocatePage(128);
     final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
@@ -49,8 +50,9 @@ public class PackedRecordPointerSuite {
 
   @Test
   public void offHeap() {
+    final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
     final MemoryBlock page0 = memoryManager.allocatePage(128);
     final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 1ef3c5f..5049a53 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -24,11 +24,11 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 public class ShuffleInMemorySorterSuite {
 
@@ -58,8 +58,9 @@ public class ShuffleInMemorySorterSuite {
       "Lychee",
       "Mango"
     };
+    final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
     final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+      new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
     final MemoryBlock dataPage = memoryManager.allocatePage(2048);
     final Object baseObject = dataPage.getBaseObject();
     final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 29d9823..d659269 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -39,7 +39,6 @@ import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.lessThan;
 import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
@@ -54,19 +53,15 @@ import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
 import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;
-  final TaskMemoryManager taskMemoryManager =
-    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  TaskMemoryManager taskMemoryManager;
   final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
   File mergedOutputFile;
   File tempDir;
@@ -76,7 +71,6 @@ public class UnsafeShuffleWriterSuite {
   final Serializer serializer = new KryoSerializer(new SparkConf());
   TaskMetrics taskMetrics;
 
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@@ -111,11 +105,11 @@ public class UnsafeShuffleWriterSuite {
     mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
-    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+    conf = new SparkConf()
+      .set("spark.buffer.pageSize", "128m")
+      .set("spark.unsafe.offHeap", "false");
     taskMetrics = new TaskMetrics();
-
-    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+    taskMemoryManager =  new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
 
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
@@ -203,7 +197,6 @@ public class UnsafeShuffleWriterSuite {
       blockManager,
       shuffleBlockResolver,
       taskMemoryManager,
-      shuffleMemoryManager,
       new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
       0, // map id
       taskContext,
@@ -405,11 +398,12 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeEnoughDataToTriggerSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to allocate new data page
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    taskMemoryManager = spy(taskMemoryManager);
+    doCallRealMethod() // initialize sort buffer
+      .doCallRealMethod() // allocate initial data page
+      .doReturn(0L) // deny request to allocate new page
+      .doCallRealMethod() // grant new sort buffer and data page
+      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
     final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
@@ -417,7 +411,7 @@ public class UnsafeShuffleWriterSuite {
       dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
     }
     writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -432,18 +426,19 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to grow sort buffer
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    taskMemoryManager = spy(taskMemoryManager);
+    doCallRealMethod() // initialize sort buffer
+      .doCallRealMethod() // allocate initial data page
+      .doReturn(0L) // deny request to allocate new page
+      .doCallRealMethod() // grant new sort buffer and data page
+      .when(taskMemoryManager).acquireExecutionMemory(anyLong());
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
     assertEquals(2, spillFilesCreated.size());
     writer.stop(true);
     readRecordsFromFile();
@@ -509,13 +504,13 @@ public class UnsafeShuffleWriterSuite {
     final long recordLengthBytes = 8;
     final long pageSizeBytes = 256;
     final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+    taskMemoryManager = spy(taskMemoryManager);
+    when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
     final UnsafeShuffleWriter<Object, Object> writer =
       new UnsafeShuffleWriter<Object, Object>(
         blockManager,
         shuffleBlockResolver,
         taskMemoryManager,
-        shuffleMemoryManager,
         new SerializedShuffleHandle<>(0, 1, shuffleDep),
         0, // map id
         taskContext,
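
The mocked ShuffleMemoryManager is gone; these tests now spy on a real TaskMemoryManager and script individual acquireExecutionMemory calls. A standalone sketch of the spy idiom used above (conf and the surrounding setup are assumed from the suite):

    import org.mockito.Matchers.anyLong
    import org.mockito.Mockito.{doCallRealMethod, doReturn, spy}

    val tmm = spy(new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0))
    doCallRealMethod()      // first call: granted for real
      .doReturn(0L)         // second call: denied, forcing a spill
      .doCallRealMethod()   // later calls: real behavior again
      .when(tmm).acquireExecutionMemory(anyLong())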

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index ab480b6..6e52496 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -21,15 +21,13 @@ import java.lang.Exception;
 import java.nio.ByteBuffer;
 import java.util.*;
 
+import org.apache.spark.memory.TaskMemoryManager;
 import org.junit.*;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.junit.Assert.*;
-import static org.mockito.AdditionalMatchers.geq;
-import static org.mockito.Mockito.*;
 
-import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.memory.*;
 import org.apache.spark.unsafe.Platform;
@@ -39,42 +37,29 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
-  private ShuffleMemoryManager shuffleMemoryManager;
+  private GrantEverythingMemoryManager memoryManager;
   private TaskMemoryManager taskMemoryManager;
-  private TaskMemoryManager sizeLimitedTaskMemoryManager;
   private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
   @Before
   public void setup() {
-    shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES);
-    taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
-    // Mocked memory manager for tests that check the maximum array size, since actually allocating
-    // such large arrays will cause us to run out of memory in our tests.
-    sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
-    when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
-      new Answer<MemoryBlock>() {
-        @Override
-        public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
-          if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
-            throw new OutOfMemoryError("Requested array size exceeds VM limit");
-          }
-          return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
-        }
-      }
-    );
+    memoryManager =
+      new GrantEverythingMemoryManager(
+        new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
+    taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   }
 
   @After
   public void tearDown() {
     Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
-    if (shuffleMemoryManager != null) {
-      long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
-      shuffleMemoryManager = null;
-      Assert.assertEquals(0L, leakedShuffleMemory);
+    if (taskMemoryManager != null) {
+      long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
+      taskMemoryManager = null;
+      Assert.assertEquals(0L, leakedMemory);
     }
   }
 
-  protected abstract MemoryAllocator getMemoryAllocator();
+  protected abstract boolean useOffHeapMemoryAllocator();
 
   private static byte[] getByteArray(MemoryLocation loc, int size) {
     final byte[] arr = new byte[size];
@@ -110,8 +95,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void emptyMap() {
-    BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
     try {
       Assert.assertEquals(0, map.numElements());
       final int keyLengthInWords = 10;
@@ -126,8 +110,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void setAndRetrieveAKey() {
-    BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
     final int recordLengthWords = 10;
     final int recordLengthBytes = recordLengthWords * 8;
     final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -179,8 +162,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   private void iteratorTestBase(boolean destructive) throws Exception {
     final int size = 4096;
-    BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES);
     try {
       for (long i = 0; i < size; i++) {
         final long[] value = new long[] { i };
@@ -265,8 +247,8 @@ public abstract class AbstractBytesToBytesMapSuite {
     final int NUM_ENTRIES = 1000 * 1000;
     final int KEY_LENGTH = 24;
     final int VALUE_LENGTH = 40;
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+    final BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
     // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte
     // pages won't be evenly-divisible by records of this size, which will cause us to waste some
     // space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -335,9 +317,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
     // into ByteBuffers in order to use them as keys here.
     final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
-
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES);
     try {
       // Fill the map to 90% full so that we can trigger probing
       for (int i = 0; i < size * 0.9; i++) {
@@ -386,8 +366,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void randomizedTestWithRecordsLargerThanPageSize() {
     final long pageSizeBytes = 128;
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes);
     // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
     // into ByteBuffers in order to use them as keys here.
     final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
@@ -436,9 +415,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void failureToAllocateFirstPage() {
-    shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024);
-    BytesToBytesMap map =
-      new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    memoryManager.markExecutionAsOutOfMemory();
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
+    memoryManager.markExecutionAsOutOfMemory();
     try {
       final long[] emptyArray = new long[0];
       final BytesToBytesMap.Location loc =
@@ -454,12 +433,14 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void failureToGrow() {
-    shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10);
-    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024);
     try {
       boolean success = true;
       int i;
-      for (i = 0; i < 1024; i++) {
+      for (i = 0; i < 127; i++) {
+        if (i > 0) {
+          memoryManager.markExecutionAsOutOfMemory();
+        }
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
         success =
@@ -478,7 +459,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void initialCapacityBoundsChecking() {
     try {
-      new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
+      new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
@@ -486,36 +467,13 @@ public abstract class AbstractBytesToBytesMapSuite {
 
     try {
       new BytesToBytesMap(
-        sizeLimitedTaskMemoryManager,
-        shuffleMemoryManager,
+        taskMemoryManager,
         BytesToBytesMap.MAX_CAPACITY + 1,
         PAGE_SIZE_BYTES);
       Assert.fail("Expected IllegalArgumentException to be thrown");
     } catch (IllegalArgumentException e) {
       // expected exception
     }
-
-    // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
-    // Can allocate _at_ the max capacity
-    //    BytesToBytesMap map = new BytesToBytesMap(
-    //      sizeLimitedTaskMemoryManager,
-    //      shuffleMemoryManager,
-    //      BytesToBytesMap.MAX_CAPACITY,
-    //      PAGE_SIZE_BYTES);
-    //    map.free();
-  }
-
-  // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
-  @Ignore
-  public void resizingLargeMap() {
-    // As long as a map's capacity is below the max, we should be able to resize up to the max
-    BytesToBytesMap map = new BytesToBytesMap(
-      sizeLimitedTaskMemoryManager,
-      shuffleMemoryManager,
-      BytesToBytesMap.MAX_CAPACITY - 64,
-      PAGE_SIZE_BYTES);
-    map.growAndRehash();
-    map.free();
   }
 
   @Test
@@ -523,8 +481,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     final long recordLengthBytes = 24;
     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
 
     // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
     // monotonically increasing. More specifically, every time we allocate a new page it
@@ -564,8 +521,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void testAcquirePageInConstructor() {
-    final BytesToBytesMap map = new BytesToBytesMap(
-      taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
     assertEquals(1, map.getNumDataPages());
     map.free();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
index 5a10de4..f0bad4d 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -17,13 +17,10 @@
 
 package org.apache.spark.unsafe.map;
 
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-
 public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite {
 
   @Override
-  protected MemoryAllocator getMemoryAllocator() {
-    return MemoryAllocator.UNSAFE;
+  protected boolean useOffHeapMemoryAllocator() {
+    return true;
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
index 12cc9b2..d76bb4f 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -17,13 +17,10 @@
 
 package org.apache.spark.unsafe.map;
 
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-
 public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite {
 
   @Override
-  protected MemoryAllocator getMemoryAllocator() {
-    return MemoryAllocator.HEAP;
+  protected boolean useOffHeapMemoryAllocator() {
+    return false;
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a5bbaa9..94d50b9 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -46,20 +46,19 @@ import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
 public class UnsafeExternalSorterSuite {
 
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
-  final TaskMemoryManager taskMemoryManager =
-    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final GrantEverythingMemoryManager memoryManager =
+    new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+  final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = new PrefixComparator() {
     @Override
@@ -82,7 +81,6 @@ public class UnsafeExternalSorterSuite {
 
   SparkConf sparkConf;
   File tempDir;
-  ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@@ -102,7 +100,6 @@ public class UnsafeExternalSorterSuite {
     MockitoAnnotations.initMocks(this);
     sparkConf = new SparkConf();
     tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
-    shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes);
     spillFilesCreated.clear();
     taskContext = mock(TaskContext.class);
     when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
@@ -143,13 +140,7 @@ public class UnsafeExternalSorterSuite {
   @After
   public void tearDown() {
     try {
-      long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
-      if (shuffleMemoryManager != null) {
-        long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
-        shuffleMemoryManager = null;
-        assertEquals(0L, leakedShuffleMemory);
-      }
-      assertEquals(0, leakedUnsafeMemory);
+      assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
     } finally {
       Utils.deleteRecursively(tempDir);
       tempDir = null;
@@ -178,7 +169,6 @@ public class UnsafeExternalSorterSuite {
   private UnsafeExternalSorter newSorter() throws IOException {
     return UnsafeExternalSorter.create(
       taskMemoryManager,
-      shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,
@@ -236,12 +226,16 @@ public class UnsafeExternalSorterSuite {
 
   @Test
   public void spillingOccursInResponseToMemoryPressure() throws Exception {
-    shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes);
     final UnsafeExternalSorter sorter = newSorter();
-    final int numRecords = (int) pageSizeBytes / 4;
-    for (int i = 0; i <= numRecords; i++) {
+    // This should be enough records to completely fill up a data page:
+    final int numRecords = (int) (pageSizeBytes / (4 + 4));
+    for (int i = 0; i < numRecords; i++) {
       insertNumber(sorter, numRecords - i);
     }
+    assertEquals(1, sorter.getNumberOfAllocatedPages());
+    memoryManager.markExecutionAsOutOfMemory();
+    // The insertion of this record should trigger a spill:
+    insertNumber(sorter, 0);
     // Ensure that spill files were created
     assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1));
     // Read back the sorted data:
@@ -255,6 +249,7 @@ public class UnsafeExternalSorterSuite {
       assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
       i++;
     }
+    assertEquals(numRecords + 1, i);
     sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
@@ -323,7 +318,6 @@ public class UnsafeExternalSorterSuite {
     final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
     final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
       taskMemoryManager,
-      shuffleMemoryManager,
       blockManager,
       taskContext,
       recordComparator,
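
The rewritten spill test no longer caps a ShuffleMemoryManager; it fills one data page exactly and then denies the next allocation. The arithmetic behind numRecords, with a hypothetical page size and the record layout read from the "4 + 4" above (assumed to be a 4-byte length prefix plus a 4-byte int payload):

    val pageSizeBytes = 4L << 20                        // hypothetical 4 MB page
    val numRecords = (pageSizeBytes / (4 + 4)).toInt    // exactly fills one page
    // After numRecords insertions the page is full; markExecutionAsOutOfMemory()
    // then denies the page request for record numRecords + 1, so the sorter
    // must spill instead of growing.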

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 778e813..d5de56a 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -26,11 +26,11 @@ import static org.junit.Assert.*;
 import static org.mockito.Mockito.mock;
 
 import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
 
 public class UnsafeInMemorySorterSuite {
 
@@ -43,7 +43,8 @@ public class UnsafeInMemorySorterSuite {
   @Test
   public void testSortingEmptyInput() {
     final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+      new TaskMemoryManager(
+        new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
       mock(RecordComparator.class),
       mock(PrefixComparator.class),
       100);
@@ -64,8 +65,8 @@ public class UnsafeInMemorySorterSuite {
       "Lychee",
       "Mango"
     };
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final TaskMemoryManager memoryManager = new TaskMemoryManager(
+      new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
     final MemoryBlock dataPage = memoryManager.allocatePage(2048);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/FailureSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index f58756e..0242cbc 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // cause is preserved
     val thrownDueToTaskFailure = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocate(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128)
         throw new Exception("intentional task failure")
         iter
       }.count()
@@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
     // If the task succeeded but memory was leaked, then the task should fail due to that leak
     val thrownDueToMemoryLeak = intercept[SparkException] {
       sc.parallelize(Seq(0)).mapPartitions { iter =>
-        TaskContext.get().taskMemoryManager().allocate(128)
+        TaskContext.get().taskMemoryManager().allocatePage(128)
         iter
       }.count()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
new file mode 100644
index 0000000..fe102d8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.{BlockStatus, BlockId}
+
+class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
+  private[memory] override def doAcquireExecutionMemory(
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
+    if (oom) {
+      oom = false
+      0
+    } else {
+      _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
+      numBytes
+    }
+  }
+  override def acquireStorageMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def acquireUnrollMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def releaseStorageMemory(numBytes: Long): Unit = { }
+  override def maxExecutionMemory: Long = Long.MaxValue
+  override def maxStorageMemory: Long = Long.MaxValue
+
+  private var oom = false
+
+  def markExecutionAsOutOfMemory(): Unit = {
+    oom = true
+  }
+}
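
Behavior sketch for this test manager: every request is granted at face value, and markExecutionAsOutOfMemory() arms a one-shot denial that the next doAcquireExecutionMemory call consumes. A minimal usage sketch, assuming requests pass straight through to doAcquireExecutionMemory, as the suites above rely on:

    val mm = new GrantEverythingMemoryManager(
      new SparkConf().set("spark.unsafe.offHeap", "false"))
    val tmm = new TaskMemoryManager(mm, 0)
    assert(tmm.acquireExecutionMemory(100L) == 100L)   // granted
    mm.markExecutionAsOutOfMemory()
    assert(tmm.acquireExecutionMemory(100L) == 0L)     // denied exactly once
    assert(tmm.acquireExecutionMemory(100L) == 100L)   // granted again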

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 36e4566..1265087 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -19,10 +19,14 @@ package org.apache.spark.memory
 
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, Future}
+
 import org.mockito.Matchers.{any, anyLong}
 import org.mockito.Mockito.{mock, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.storage.MemoryStore
@@ -126,6 +130,136 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
     assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED,
       "ensure free space should not have been called!")
   }
+
+  /**
+   * Create a MemoryManager with the specified execution memory limit and no storage memory.
+   */
+  protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager
+
+  // -- Tests of sharing of execution memory between tasks ----------------------------------------
+  // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite.
+
+  implicit val ec = ExecutionContext.global
+
+  test("single task requesting execution memory") {
+    val manager = createMemoryManager(1000L)
+    val taskMemoryManager = new TaskMemoryManager(manager, 0)
+
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
+    assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+
+    taskMemoryManager.releaseExecutionMemory(500L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
+    assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
+
+    taskMemoryManager.cleanUpAllAllocatedMemory()
+    assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
+    assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+  }
+
+  test("two tasks requesting full execution memory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // Have both tasks request 500 bytes, then wait until both requests have been granted:
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t1Result1, futureTimeout) === 500L)
+    assert(Await.result(t2Result1, futureTimeout) === 500L)
+
+    // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
+    // both now at 1 / N
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t1Result2, 200.millis) === 0L)
+    assert(Await.result(t2Result2, 200.millis) === 0L)
+  }
+
+  test("two tasks cannot grow past 1 / N of execution memory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // Have both tasks request 250 bytes, then wait until both requests have been granted:
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    assert(Await.result(t1Result1, futureTimeout) === 250L)
+    assert(Await.result(t2Result1, futureTimeout) === 250L)
+
+    // Have both tasks each request 500 bytes more.
+    // We should only grant 250 bytes to each of them on this second request
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t1Result2, futureTimeout) === 250L)
+    assert(Await.result(t2Result2, futureTimeout) === 250L)
+  }
+
+  test("tasks can block to get at least 1 / 2N of execution memory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    assert(Await.result(t1Result1, futureTimeout) === 1000L)
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+    // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
+    // to make sure the other thread blocks for some time otherwise.
+    Thread.sleep(300)
+    t1MemManager.releaseExecutionMemory(250L)
+    // The memory freed from t1 should now be granted to t2.
+    assert(Await.result(t2Result1, futureTimeout) === 250L)
+    // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
+    assert(Await.result(t2Result2, 200.millis) === 0L)
+  }
+
+  test("TaskMemoryManager.cleanUpAllAllocatedMemory") {
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+    assert(Await.result(t1Result1, futureTimeout) === 1000L)
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
+    // to make sure the other thread blocks for some time otherwise.
+    Thread.sleep(300)
+    // t1 releases all of its memory, so t2 should be able to grab all of the memory
+    t1MemManager.cleanUpAllAllocatedMemory()
+    assert(Await.result(t2Result1, futureTimeout) === 500L)
+    val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t2Result2, futureTimeout) === 500L)
+    val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
+    assert(Await.result(t2Result3, 200.millis) === 0L)
+  }
+
+  test("tasks should not be granted a negative amount of execution memory") {
+    // This is a regression test for SPARK-4715.
+    val memoryManager = createMemoryManager(1000L)
+    val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+    val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+    val futureTimeout: Duration = 20.seconds
+
+    val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
+    assert(Await.result(t1Result1, futureTimeout) === 700L)
+
+    val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
+    assert(Await.result(t2Result1, futureTimeout) === 300L)
+
+    val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
+    assert(Await.result(t1Result2, 200.millis) === 0L)
+  }
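
For concreteness, the clamp this regression test protects (the max(0, ...) in the grant computation) can be checked by hand; the variable names below are illustrative only:

val maxMemory = 1000L
val numActiveTasks = 2
val curMem = 700L // t1 already holds more than its 1 / N share
val maxToGrant = math.min(300L, math.max(0L, maxMemory / numActiveTasks - curMem))
assert(maxToGrant == 0L) // without the clamp this would be min(300, -200) = -200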
 }
 
 private object MemoryManagerSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
new file mode 100644
index 0000000..4b4c3b0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+
+/**
+ * Helper methods for mocking out memory-management-related classes in tests.
+ */
+object MemoryTestingUtils {
+  def fakeTaskContext(env: SparkEnv): TaskContext = {
+    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
+    new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 0,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemoryManager,
+      metricsSystem = env.metricsSystem,
+      internalAccumulators = Seq.empty)
+  }
+}
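
For reference, a suite with a live SparkEnv (for example via LocalSparkContext) can use this helper to build spillable collections outside of a real task. A minimal usage sketch, with illustrative type parameters, assuming it runs inside a test body:

import org.apache.spark.SparkEnv
import org.apache.spark.util.collection.ExternalAppendOnlyMap

val context = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
val map = new ExternalAppendOnlyMap[Int, Int, Int](
  (i: Int) => i, _ + _, _ + _, context = context)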

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
index 6cae1f8..885c450 100644
--- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
@@ -36,27 +36,35 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite {
       maxExecutionMem: Long,
       maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = {
     val mm = new StaticMemoryManager(
-      conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem)
+      conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1)
     val ms = makeMemoryStore(mm)
     (mm, ms)
   }
 
+  override protected def createMemoryManager(maxMemory: Long): MemoryManager = {
+    new StaticMemoryManager(
+      conf,
+      maxExecutionMemory = maxMemory,
+      maxStorageMemory = 0,
+      numCores = 1)
+  }
+
   test("basic execution memory") {
     val maxExecutionMem = 1000L
     val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue)
     assert(mm.executionMemoryUsed === 0L)
-    assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L)
+    assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L)
     assert(mm.executionMemoryUsed === 10L)
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     // Acquire up to the max
-    assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L)
+    assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L)
     assert(mm.executionMemoryUsed === maxExecutionMem)
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L)
     assert(mm.executionMemoryUsed === maxExecutionMem)
     mm.releaseExecutionMemory(800L)
     assert(mm.executionMemoryUsed === 200L)
     // Acquire after release
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L)
     assert(mm.executionMemoryUsed === 201L)
     // Release beyond what was acquired
     mm.releaseExecutionMemory(maxExecutionMem)
@@ -108,10 +116,10 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite {
     val dummyBlock = TestBlockId("ain't nobody love like you do")
     val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem)
     // Only execution memory should increase
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     assert(mm.storageMemoryUsed === 0L)
     assert(mm.executionMemoryUsed === 100L)
-    assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L)
     assert(mm.storageMemoryUsed === 0L)
     assert(mm.executionMemoryUsed === 200L)
     // Only storage memory should increase

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
index e7baa50..0c97f2b 100644
--- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
@@ -34,11 +34,15 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
    * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies.
    */
   private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = {
-    val mm = new UnifiedMemoryManager(conf, maxMemory)
+    val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1)
     val ms = makeMemoryStore(mm)
     (mm, ms)
   }
 
+  override protected def createMemoryManager(maxMemory: Long): MemoryManager = {
+    new UnifiedMemoryManager(conf, maxMemory, numCores = 1)
+  }
+
   private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = {
     mm invokePrivate PrivateMethod[Long]('storageRegionSize)()
   }
@@ -56,18 +60,18 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     val maxMemory = 1000L
     val (mm, _) = makeThings(maxMemory)
     assert(mm.executionMemoryUsed === 0L)
-    assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L)
+    assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L)
     assert(mm.executionMemoryUsed === 10L)
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     // Acquire up to the max
-    assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L)
+    assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L)
     assert(mm.executionMemoryUsed === maxMemory)
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L)
     assert(mm.executionMemoryUsed === maxMemory)
     mm.releaseExecutionMemory(800L)
     assert(mm.executionMemoryUsed === 200L)
     // Acquire after release
-    assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L)
+    assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L)
     assert(mm.executionMemoryUsed === 201L)
     // Release beyond what was acquired
     mm.releaseExecutionMemory(maxMemory)
@@ -132,12 +136,12 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     require(mm.storageMemoryUsed > storageRegionSize,
       s"bad test: storage memory used should exceed the storage region")
     // Execution needs to request 250 bytes to evict storage memory
-    assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+    assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
     assert(mm.executionMemoryUsed === 100L)
     assert(mm.storageMemoryUsed === 750L)
     assertEnsureFreeSpaceNotCalled(ms)
     // Execution wants 200 bytes but only 150 are free, so storage is evicted
-    assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L)
+    assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L)
     assertEnsureFreeSpaceCalled(ms, 200L)
     assert(mm.executionMemoryUsed === 300L)
     mm.releaseAllStorageMemory()
@@ -151,7 +155,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
       s"bad test: storage memory used should be within the storage region")
     // Execution cannot evict storage because the latter is within the storage fraction,
     // so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300
-    assert(mm.acquireExecutionMemory(400L, evictedBlocks) === 300L)
+    assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L)
     assert(mm.executionMemoryUsed === 600L)
     assert(mm.storageMemoryUsed === 400L)
     assertEnsureFreeSpaceNotCalled(ms)
@@ -170,7 +174,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     require(executionRegionSize === expectedExecutionRegionSize,
       "bad test: storage region size is unexpected")
     // Acquire enough execution memory to exceed the execution region
-    assert(mm.acquireExecutionMemory(800L, evictedBlocks) === 800L)
+    assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L)
     assert(mm.executionMemoryUsed === 800L)
     assert(mm.storageMemoryUsed === 0L)
     assertEnsureFreeSpaceNotCalled(ms)
@@ -188,7 +192,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
     mm.releaseExecutionMemory(maxMemory)
     mm.releaseStorageMemory(maxMemory)
     // Acquire some execution memory again, but this time keep it within the execution region
-    assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L)
+    assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L)
     assert(mm.executionMemoryUsed === 200L)
     assert(mm.storageMemoryUsed === 0L)
     assertEnsureFreeSpaceNotCalled(ms)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
deleted file mode 100644
index 5877aa0..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
+++ /dev/null
@@ -1,326 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle
-
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.mockito.Mockito._
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.time.SpanSugar._
-
-import org.apache.spark.{SparkFunSuite, TaskContext}
-import org.apache.spark.executor.TaskMetrics
-
-class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
-
-  val nextTaskAttemptId = new AtomicInteger()
-
-  /** Launch a thread with the given body block and return it. */
-  private def startThread(name: String)(body: => Unit): Thread = {
-    val thread = new Thread("ShuffleMemorySuite " + name) {
-      override def run() {
-        try {
-          val taskAttemptId = nextTaskAttemptId.getAndIncrement
-          val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS)
-          val taskMetrics = new TaskMetrics
-          when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
-          when(mockTaskContext.taskMetrics()).thenReturn(taskMetrics)
-          TaskContext.setTaskContext(mockTaskContext)
-          body
-        } finally {
-          TaskContext.unset()
-        }
-      }
-    }
-    thread.start()
-    thread
-  }
-
-  test("single task requesting memory") {
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    assert(manager.tryToAcquire(100L) === 100L)
-    assert(manager.tryToAcquire(400L) === 400L)
-    assert(manager.tryToAcquire(400L) === 400L)
-    assert(manager.tryToAcquire(200L) === 100L)
-    assert(manager.tryToAcquire(100L) === 0L)
-    assert(manager.tryToAcquire(100L) === 0L)
-
-    manager.release(500L)
-    assert(manager.tryToAcquire(300L) === 300L)
-    assert(manager.tryToAcquire(300L) === 200L)
-
-    manager.releaseMemoryForThisTask()
-    assert(manager.tryToAcquire(1000L) === 1000L)
-    assert(manager.tryToAcquire(100L) === 0L)
-  }
-
-  test("two threads requesting full memory") {
-    // Two threads request 500 bytes first, wait for each other to get it, and then request
-    // 500 more; we should immediately return 0 as both are now at 1 / N
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Result1 = -1L
-      var t2Result1 = -1L
-      var t1Result2 = -1L
-      var t2Result2 = -1L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      val r1 = manager.tryToAcquire(500L)
-      state.synchronized {
-        state.t1Result1 = r1
-        state.notifyAll()
-        while (state.t2Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t1Result2 = r2 }
-    }
-
-    val t2 = startThread("t2") {
-      val r1 = manager.tryToAcquire(500L)
-      state.synchronized {
-        state.t2Result1 = r1
-        state.notifyAll()
-        while (state.t1Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t2Result2 = r2 }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    assert(state.t1Result1 === 500L)
-    assert(state.t2Result1 === 500L)
-    assert(state.t1Result2 === 0L)
-    assert(state.t2Result2 === 0L)
-  }
-
-
-  test("tasks cannot grow past 1 / N") {
-    // Two tasks request 250 bytes first, wait for each other to get it, and then request
-    // 500 more; we should only grant 250 bytes to each of them on this second request
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Result1 = -1L
-      var t2Result1 = -1L
-      var t1Result2 = -1L
-      var t2Result2 = -1L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      val r1 = manager.tryToAcquire(250L)
-      state.synchronized {
-        state.t1Result1 = r1
-        state.notifyAll()
-        while (state.t2Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t1Result2 = r2 }
-    }
-
-    val t2 = startThread("t2") {
-      val r1 = manager.tryToAcquire(250L)
-      state.synchronized {
-        state.t2Result1 = r1
-        state.notifyAll()
-        while (state.t1Result1 === -1L) {
-          state.wait()
-        }
-      }
-      val r2 = manager.tryToAcquire(500L)
-      state.synchronized { state.t2Result2 = r2 }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    assert(state.t1Result1 === 250L)
-    assert(state.t2Result1 === 250L)
-    assert(state.t1Result2 === 250L)
-    assert(state.t2Result2 === 250L)
-  }
-
-  test("tasks can block to get at least 1 / 2N memory") {
-    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
-    // for a bit and releases 250 bytes, which should then be granted to t2. Further requests
-    // by t2 will return false right away because it now has 1 / 2N of the memory.
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Requested = false
-      var t2Requested = false
-      var t1Result = -1L
-      var t2Result = -1L
-      var t2Result2 = -1L
-      var t2WaitTime = 0L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      state.synchronized {
-        state.t1Result = manager.tryToAcquire(1000L)
-        state.t1Requested = true
-        state.notifyAll()
-        while (!state.t2Requested) {
-          state.wait()
-        }
-      }
-      // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
-      // sure the other thread blocks for some time otherwise
-      Thread.sleep(300)
-      manager.release(250L)
-    }
-
-    val t2 = startThread("t2") {
-      state.synchronized {
-        while (!state.t1Requested) {
-          state.wait()
-        }
-        state.t2Requested = true
-        state.notifyAll()
-      }
-      val startTime = System.currentTimeMillis()
-      val result = manager.tryToAcquire(250L)
-      val endTime = System.currentTimeMillis()
-      state.synchronized {
-        state.t2Result = result
-        // A second call should return 0 because we're now already at 1 / 2N
-        state.t2Result2 = manager.tryToAcquire(100L)
-        state.t2WaitTime = endTime - startTime
-      }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    // Both threads should've been able to acquire their memory; the second one will have waited
-    // until the first one acquired 1000 bytes and then released 250
-    state.synchronized {
-      assert(state.t1Result === 1000L, "t1 could not allocate memory")
-      assert(state.t2Result === 250L, "t2 could not allocate memory")
-      assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
-      assert(state.t2Result2 === 0L, "t1 got extra memory the second time")
-    }
-  }
-
-  test("releaseMemoryForThisTask") {
-    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
-    // for a bit and releases all its memory. t2 should now be able to grab all the memory.
-
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
-    class State {
-      var t1Requested = false
-      var t2Requested = false
-      var t1Result = -1L
-      var t2Result1 = -1L
-      var t2Result2 = -1L
-      var t2Result3 = -1L
-      var t2WaitTime = 0L
-    }
-    val state = new State
-
-    val t1 = startThread("t1") {
-      state.synchronized {
-        state.t1Result = manager.tryToAcquire(1000L)
-        state.t1Requested = true
-        state.notifyAll()
-        while (!state.t2Requested) {
-          state.wait()
-        }
-      }
-      // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
-      // sure the other task blocks for some time otherwise
-      Thread.sleep(300)
-      manager.releaseMemoryForThisTask()
-    }
-
-    val t2 = startThread("t2") {
-      state.synchronized {
-        while (!state.t1Requested) {
-          state.wait()
-        }
-        state.t2Requested = true
-        state.notifyAll()
-      }
-      val startTime = System.currentTimeMillis()
-      val r1 = manager.tryToAcquire(500L)
-      val endTime = System.currentTimeMillis()
-      val r2 = manager.tryToAcquire(500L)
-      val r3 = manager.tryToAcquire(500L)
-      state.synchronized {
-        state.t2Result1 = r1
-        state.t2Result2 = r2
-        state.t2Result3 = r3
-        state.t2WaitTime = endTime - startTime
-      }
-    }
-
-    failAfter(20 seconds) {
-      t1.join()
-      t2.join()
-    }
-
-    // Both tasks should've been able to acquire their memory; the second one will have waited
-    // until the first one acquired 1000 bytes and then released all of it
-    state.synchronized {
-      assert(state.t1Result === 1000L, "t1 could not allocate memory")
-      assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time")
-      assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time")
-      assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})")
-      assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
-    }
-  }
-
-  test("tasks should not be granted a negative size") {
-    val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-    manager.tryToAcquire(700L)
-
-    val latch = new CountDownLatch(1)
-    startThread("t1") {
-      manager.tryToAcquire(300L)
-      latch.countDown()
-    }
-    latch.await() // Wait until `t1` calls `tryToAcquire`
-
-    val granted = manager.tryToAcquire(300L)
-    assert(0 === granted, "granted is negative")
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index cc44c67..6e3f500 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -61,7 +61,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
     val store = new BlockManager(name, rpcEnv, master, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     memManager.setMemoryStore(store.memoryStore)
@@ -261,7 +261,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
     val failableTransfer = mock(classOf[BlockTransferService]) // this won't actually work
     when(failableTransfer.hostName).thenReturn("some-hostname")
     when(failableTransfer.port).thenReturn(1000)
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1)
     val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
     memManager.setMemoryStore(failableStore.memoryStore)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f3fab33..d49015a 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -68,7 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
     val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
     memManager.setMemoryStore(blockManager.memoryStore)
@@ -823,7 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   test("block store put failure") {
     // Use Java serializer so we can create an unserializable error.
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200)
+    val memoryManager = new StaticMemoryManager(
+      conf,
+      maxExecutionMemory = Long.MaxValue,
+      maxStorageMemory = 1200,
+      numCores = 1)
     store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
       new JavaSerializer(conf), conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/85e654c5/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5cb506e..dc3185a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
 import org.apache.spark.io.CompressionCodec
-
+import org.apache.spark.memory.MemoryTestingUtils
 
 class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
   import TestUtils.{assertNotSpilled, assertSpilled}
@@ -32,8 +32,11 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
   private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] =
     buf1 ++= buf2
 
-  private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
-    createCombiner[T], mergeValue[T], mergeCombiners[T])
+  private def createExternalMap[T] = {
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+    new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
+      createCombiner[T], mergeValue[T], mergeCombiners[T], context = context)
+  }
 
   private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = {
     val conf = new SparkConf(loadDefaults)
@@ -49,23 +52,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     conf
   }
 
-  test("simple insert") {
+  test("single insert insert") {
     val conf = createSparkConf(loadDefaults = false)
     sc = new SparkContext("local", "test", conf)
     val map = createExternalMap[Int]
-
-    // Single insert
     map.insert(1, 10)
-    var it = map.iterator
+    val it = map.iterator
     assert(it.hasNext)
     val kv = it.next()
     assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10))
     assert(!it.hasNext)
+    sc.stop()
+  }
 
-    // Multiple insert
+  test("multiple insert") {
+    val conf = createSparkConf(loadDefaults = false)
+    sc = new SparkContext("local", "test", conf)
+    val map = createExternalMap[Int]
+    map.insert(1, 10)
     map.insert(2, 20)
     map.insert(3, 30)
-    it = map.iterator
+    val it = map.iterator
     assert(it.hasNext)
     assert(it.toSet === Set[(Int, ArrayBuffer[Int])](
       (1, ArrayBuffer[Int](10)),
@@ -144,39 +151,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
 
     val map = createExternalMap[Int]
+    val nullInt = null.asInstanceOf[Int]
     map.insert(1, 5)
     map.insert(2, 6)
     map.insert(3, 7)
-    assert(map.size === 3)
-    assert(map.iterator.toSet === Set[(Int, Seq[Int])](
-      (1, Seq[Int](5)),
-      (2, Seq[Int](6)),
-      (3, Seq[Int](7))
-    ))
-
-    // Null keys
-    val nullInt = null.asInstanceOf[Int]
+    map.insert(4, nullInt)
     map.insert(nullInt, 8)
-    assert(map.size === 4)
-    assert(map.iterator.toSet === Set[(Int, Seq[Int])](
+    map.insert(nullInt, nullInt)
+    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted))
+    assert(result === Set[(Int, Seq[Int])](
       (1, Seq[Int](5)),
       (2, Seq[Int](6)),
       (3, Seq[Int](7)),
-      (nullInt, Seq[Int](8))
+      (4, Seq[Int](nullInt)),
+      (nullInt, Seq[Int](nullInt, 8))
     ))
 
-    // Null values
-    map.insert(4, nullInt)
-    map.insert(nullInt, nullInt)
-    assert(map.size === 5)
-    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
-    assert(result === Set[(Int, Set[Int])](
-      (1, Set[Int](5)),
-      (2, Set[Int](6)),
-      (3, Set[Int](7)),
-      (4, Set[Int](nullInt)),
-      (nullInt, Set[Int](nullInt, 8))
-    ))
     sc.stop()
   }
 
@@ -344,7 +334,9 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     val conf = createSparkConf(loadDefaults = true)
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
-    val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+    val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+    val map =
+      new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, context = context)
 
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).

