You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2015/10/09 06:45:08 UTC

spark git commit: [SPARK-10956] Common MemoryManager interface for storage and execution

Repository: spark
Updated Branches:
  refs/heads/master 098412900 -> 67fbecbf3


[SPARK-10956] Common MemoryManager interface for storage and execution

This patch introduces a `MemoryManager` that is the central arbiter of how much memory to grant to storage and execution. This patch is concerned primarily with refactoring, preserving the existing behavior as much as possible.

This is the first step away from the existing rigid separation of storage and execution memory, which has several major drawbacks discussed in the [issue](https://issues.apache.org/jira/browse/SPARK-10956). It is the precursor to a series of patches that will attempt to address those drawbacks.

Author: Andrew Or <an...@databricks.com>
Author: Josh Rosen <jo...@databricks.com>
Author: andrewor14 <an...@databricks.com>

Closes #9000 from andrewor14/memory-manager.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/67fbecbf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/67fbecbf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/67fbecbf

Branch: refs/heads/master
Commit: 67fbecbf32fced87d3accd2618fef2af9f44fae2
Parents: 0984129
Author: Andrew Or <an...@databricks.com>
Authored: Thu Oct 8 21:44:59 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Thu Oct 8 21:44:59 2015 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/SparkEnv.scala  |  11 +-
 .../org/apache/spark/memory/MemoryManager.scala | 117 ++++++++
 .../spark/memory/StaticMemoryManager.scala      | 202 ++++++++++++++
 .../spark/shuffle/ShuffleMemoryManager.scala    |  69 +++--
 .../org/apache/spark/storage/BlockManager.scala |  33 +--
 .../org/apache/spark/storage/MemoryStore.scala  | 272 +++++++++----------
 .../spark/memory/StaticMemoryManagerSuite.scala | 172 ++++++++++++
 .../storage/BlockManagerReplicationSuite.scala  |  29 +-
 .../spark/storage/BlockManagerSuite.scala       |  34 ++-
 .../execution/TestShuffleMemoryManager.scala    |  28 +-
 .../streaming/ReceivedBlockHandlerSuite.scala   |  13 +-
 11 files changed, 752 insertions(+), 228 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index cfde27f..df3d84a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -30,6 +30,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.PythonWorkerFactory
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.memory.{MemoryManager, StaticMemoryManager}
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv}
@@ -69,6 +70,8 @@ class SparkEnv (
     val httpFileServer: HttpFileServer,
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
+    // TODO: unify these *MemoryManager classes (SPARK-10984)
+    val memoryManager: MemoryManager,
     val shuffleMemoryManager: ShuffleMemoryManager,
     val executorMemoryManager: ExecutorMemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
@@ -332,7 +335,8 @@ object SparkEnv extends Logging {
     val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
     val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
 
-    val shuffleMemoryManager = ShuffleMemoryManager.create(conf, numUsableCores)
+    val memoryManager = new StaticMemoryManager(conf)
+    val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores)
 
     val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores)
 
@@ -343,8 +347,8 @@ object SparkEnv extends Logging {
 
     // NB: blockManager is not valid until initialize() is called later.
     val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
-      serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
-      numUsableCores)
+      serializer, conf, memoryManager, mapOutputTracker, shuffleManager,
+      blockTransferService, securityManager, numUsableCores)
 
     val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
 
@@ -417,6 +421,7 @@ object SparkEnv extends Logging {
       httpFileServer,
       sparkFilesDir,
       metricsSystem,
+      memoryManager,
       shuffleMemoryManager,
       executorMemoryManager,
       outputCommitCoordinator,

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
new file mode 100644
index 0000000..4bf73b6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import scala.collection.mutable
+
+import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
+
+
+/**
+ * An abstract memory manager that enforces how memory is shared between execution and storage.
+ *
+ * In this context, execution memory refers to that used for computation in shuffles, joins,
+ * sorts and aggregations, while storage memory refers to that used for caching and propagating
+ * internal data across the cluster. There exists one of these per JVM.
+ */
+private[spark] abstract class MemoryManager {
+
+  // Memory store used to evict cached blocks; `memoryStore` throws until setMemoryStore() runs
+  private var _memoryStore: MemoryStore = _
+  protected def memoryStore: MemoryStore = {
+    if (_memoryStore == null) {
+      throw new IllegalStateException("memory store not initialized yet")
+    }
+    _memoryStore
+  }
+
+  /**
+   * Set the [[MemoryStore]] used by this manager to evict cached blocks.
+   * This must be set after construction due to initialization ordering constraints.
+   */
+  def setMemoryStore(store: MemoryStore): Unit = {
+    _memoryStore = store
+  }
+
+  /**
+   * Acquire up to N bytes of memory for execution.
+   * @return number of bytes successfully granted (<= N).
+   */
+  def acquireExecutionMemory(numBytes: Long): Long
+
+  /**
+   * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
+   * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+   * @return whether all N bytes were successfully granted.
+   */
+  def acquireStorageMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean
+
+  /**
+   * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary.
+   * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+   * @return whether all N bytes were successfully granted.
+   */
+  def acquireUnrollMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean
+
+  /**
+   * Release N bytes of execution memory.
+   */
+  def releaseExecutionMemory(numBytes: Long): Unit
+
+  /**
+   * Release N bytes of storage memory.
+   */
+  def releaseStorageMemory(numBytes: Long): Unit
+
+  /**
+   * Release all storage memory acquired so far.
+   */
+  def releaseStorageMemory(): Unit
+
+  /**
+   * Release N bytes of unroll memory.
+   */
+  def releaseUnrollMemory(numBytes: Long): Unit
+
+  /**
+   * Total memory available for execution, in bytes.
+   */
+  def maxExecutionMemory: Long
+
+  /**
+   * Total memory available for storage, in bytes.
+   */
+  def maxStorageMemory: Long
+
+  /**
+   * Execution memory currently in use, in bytes.
+   */
+  def executionMemoryUsed: Long
+
+  /**
+   * Storage memory currently in use, in bytes.
+   */
+  def storageMemoryUsed: Long
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
new file mode 100644
index 0000000..150445e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import scala.collection.mutable
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.storage.{BlockId, BlockStatus}
+
+
+/**
+ * A [[MemoryManager]] that statically partitions the heap space into disjoint regions.
+ *
+ * The sizes of the execution and storage regions are determined through
+ * `spark.shuffle.memoryFraction` and `spark.storage.memoryFraction` respectively. The two
+ * regions are cleanly separated such that neither usage can borrow memory from the other.
+ */
+private[spark] class StaticMemoryManager(
+    conf: SparkConf,
+    override val maxExecutionMemory: Long,
+    override val maxStorageMemory: Long)
+  extends MemoryManager with Logging {
+
+  // Upper bound on the bytes of cached blocks that may be evicted to make room for unrolling
+  private val maxMemoryToEvictForUnroll: Long = {
+    (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong
+  }
+
+  // Amount of execution / storage memory in use
+  // Accesses must be synchronized on `this`
+  private var _executionMemoryUsed: Long = 0
+  private var _storageMemoryUsed: Long = 0
+
+  def this(conf: SparkConf) {
+    this(
+      conf,
+      StaticMemoryManager.getMaxExecutionMemory(conf),
+      StaticMemoryManager.getMaxStorageMemory(conf))
+  }
+
+  /**
+   * Acquire up to N bytes of memory for execution, capped by what is left in the region.
+   * @return number of bytes successfully granted (<= N).
+   */
+  override def acquireExecutionMemory(numBytes: Long): Long = synchronized {
+    assert(_executionMemoryUsed <= maxExecutionMemory)
+    val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed)
+    _executionMemoryUsed += bytesToGrant
+    bytesToGrant
+  }
+
+  /**
+   * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
+   * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+   * @return whether all N bytes were successfully granted.
+   */
+  override def acquireStorageMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
+    acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks)
+  }
+
+  /**
+   * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary.
+   *
+   * This evicts at most M bytes worth of existing blocks, where M is the unroll budget (the
+   * `spark.storage.unrollFraction` share of storage) minus unroll memory already in use.
+   * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+   *
+   * @return whether all N bytes were successfully granted.
+   */
+  override def acquireUnrollMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
+    val currentUnrollMemory = memoryStore.currentUnrollMemory
+    val maxNumBytesToFree = math.max(0L, maxMemoryToEvictForUnroll - currentUnrollMemory)
+    val numBytesToFree = math.min(numBytes, maxNumBytesToFree)
+    acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks)
+  }
+
+  /**
+   * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary.
+   *
+   * @param blockId the ID of the block we are acquiring storage memory for
+   * @param numBytesToAcquire the size of this block
+   * @param numBytesToFree the size of space to be freed through evicting blocks
+   * @param evictedBlocks a holder for blocks evicted in the process
+   * @return whether all `numBytesToAcquire` bytes were successfully granted.
+   */
+  private def acquireStorageMemory(
+      blockId: BlockId,
+      numBytesToAcquire: Long,
+      numBytesToFree: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
+    // Note: Keep this outside synchronized block to avoid potential deadlocks!
+    memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks)
+    synchronized {
+      assert(_storageMemoryUsed <= maxStorageMemory)  // so the check below cannot overflow
+      val enoughMemory = numBytesToAcquire <= maxStorageMemory - _storageMemoryUsed
+      if (enoughMemory) {
+        _storageMemoryUsed += numBytesToAcquire
+      }
+      enoughMemory
+    }
+  }
+
+  /**
+   * Release N bytes of execution memory.
+   */
+  override def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
+    if (numBytes > _executionMemoryUsed) {
+      logWarning(s"Attempted to release $numBytes bytes of execution " +
+        s"memory when we only have ${_executionMemoryUsed} bytes")
+      _executionMemoryUsed = 0
+    } else {
+      _executionMemoryUsed -= numBytes
+    }
+  }
+
+  /**
+   * Release N bytes of storage memory.
+   */
+  override def releaseStorageMemory(numBytes: Long): Unit = synchronized {
+    if (numBytes > _storageMemoryUsed) {
+      logWarning(s"Attempted to release $numBytes bytes of storage " +
+        s"memory when we only have ${_storageMemoryUsed} bytes")
+      _storageMemoryUsed = 0
+    } else {
+      _storageMemoryUsed -= numBytes
+    }
+  }
+
+  /**
+   * Release all storage memory acquired (including any unroll memory).
+   */
+  override def releaseStorageMemory(): Unit = synchronized {
+    _storageMemoryUsed = 0
+  }
+
+  /**
+   * Release N bytes of unroll memory; unroll memory is accounted as storage memory.
+   */
+  override def releaseUnrollMemory(numBytes: Long): Unit = {
+    releaseStorageMemory(numBytes)
+  }
+
+  /**
+   * Amount of execution memory currently in use, in bytes.
+   */
+  override def executionMemoryUsed: Long = synchronized {
+    _executionMemoryUsed
+  }
+
+  /**
+   * Amount of storage memory currently in use, in bytes.
+   */
+  override def storageMemoryUsed: Long = synchronized {
+    _storageMemoryUsed
+  }
+
+}
+
+
+private[spark] object StaticMemoryManager {
+
+  /** Total memory available for the storage region, in bytes. */
+  private def getMaxStorageMemory(conf: SparkConf): Long = scaledHeap(
+    conf.getDouble("spark.storage.memoryFraction", 0.6),
+    conf.getDouble("spark.storage.safetyFraction", 0.9))
+
+  /** Total memory available for the execution region, in bytes. */
+  private def getMaxExecutionMemory(conf: SparkConf): Long = scaledHeap(
+    conf.getDouble("spark.shuffle.memoryFraction", 0.2),
+    conf.getDouble("spark.shuffle.safetyFraction", 0.8))
+
+  /**
+   * Scale the JVM's maximum heap size by `fraction` and then by `safety`. The product is
+   * computed left to right in Double precision and truncated to a whole number of bytes.
+   */
+  private def scaledHeap(fraction: Double, safety: Double): Long = {
+    val heapBytes = Runtime.getRuntime.maxMemory
+    (heapBytes * fraction * safety).toLong
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 9839c76..bb64bb3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -21,8 +21,9 @@ import scala.collection.mutable
 
 import com.google.common.annotations.VisibleForTesting
 
+import org.apache.spark._
+import org.apache.spark.memory.{StaticMemoryManager, MemoryManager}
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
 
 /**
  * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
@@ -40,16 +41,17 @@ import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
  *
  * Use `ShuffleMemoryManager.create()` factory method to create a new instance.
  *
- * @param maxMemory total amount of memory available for execution, in bytes.
+ * @param memoryManager the interface through which this manager acquires execution memory
  * @param pageSizeBytes number of bytes for each page, by default.
  */
 private[spark]
 class ShuffleMemoryManager protected (
-    val maxMemory: Long,
+    memoryManager: MemoryManager,
     val pageSizeBytes: Long)
   extends Logging {
 
   private val taskMemory = new mutable.HashMap[Long, Long]()  // taskAttemptId -> memory bytes
+  private val maxMemory = memoryManager.maxExecutionMemory
 
   private def currentTaskAttemptId(): Long = {
     // In case this is called on the driver, return an invalid task attempt id.
@@ -71,7 +73,7 @@ class ShuffleMemoryManager protected (
     // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
     if (!taskMemory.contains(taskAttemptId)) {
       taskMemory(taskAttemptId) = 0L
-      notifyAll()  // Will later cause waiting tasks to wake up and check numThreads again
+      notifyAll()  // Will later cause waiting tasks to wake up and check numTasks again
     }
 
     // Keep looping until we're either sure that we don't want to grant this request (because this
@@ -85,46 +87,57 @@ class ShuffleMemoryManager protected (
       // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
       // don't let it be negative
       val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
+      // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
+      val toGrant = math.min(maxToGrant, freeMemory)
 
       if (curMem < maxMemory / (2 * numActiveTasks)) {
         // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
         // if we can't give it this much now, wait for other tasks to free up memory
         // (this happens if older tasks allocated lots of memory before N grew)
         if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
-          val toGrant = math.min(maxToGrant, freeMemory)
-          taskMemory(taskAttemptId) += toGrant
-          return toGrant
+          return acquire(toGrant)
         } else {
           logInfo(
             s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
           wait()
         }
       } else {
-        // Only give it as much memory as is free, which might be none if it reached 1 / numThreads
-        val toGrant = math.min(maxToGrant, freeMemory)
-        taskMemory(taskAttemptId) += toGrant
-        return toGrant
+        return acquire(toGrant)
       }
     }
     0L  // Never reached
   }
 
+  /**
+   * Request up to N bytes of execution memory from the underlying memory manager and charge
+   * the amount actually granted (<= N) to the current task. @return the bytes granted.
+   */
+  private def acquire(numBytes: Long): Long = synchronized {
+    val tid = currentTaskAttemptId()
+    val granted = memoryManager.acquireExecutionMemory(numBytes)
+    taskMemory.update(tid, taskMemory(tid) + granted)
+    granted
+  }
+
   /** Release numBytes bytes for the current task. */
   def release(numBytes: Long): Unit = synchronized {
     val taskAttemptId = currentTaskAttemptId()
     val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
     if (curMem < numBytes) {
       throw new SparkException(
-        s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}")
+        s"Internal error: release called on $numBytes bytes but task only has $curMem")
     }
     taskMemory(taskAttemptId) -= numBytes
+    memoryManager.releaseExecutionMemory(numBytes)
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 
   /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
   def releaseMemoryForThisTask(): Unit = synchronized {
     val taskAttemptId = currentTaskAttemptId()
-    taskMemory.remove(taskAttemptId)
+    taskMemory.remove(taskAttemptId).foreach { numBytes =>
+      memoryManager.releaseExecutionMemory(numBytes)
+    }
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 
@@ -138,30 +151,28 @@ class ShuffleMemoryManager protected (
 
 private[spark] object ShuffleMemoryManager {
 
-  def create(conf: SparkConf, numCores: Int): ShuffleMemoryManager = {
-    val maxMemory = ShuffleMemoryManager.getMaxMemory(conf)
+  def create(
+      conf: SparkConf,
+      memoryManager: MemoryManager,
+      numCores: Int): ShuffleMemoryManager = {
+    val maxMemory = memoryManager.maxExecutionMemory
     val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores)
-    new ShuffleMemoryManager(maxMemory, pageSize)
+    new ShuffleMemoryManager(memoryManager, pageSize)
   }
 
+  /**
+   * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size.
+   */
   def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = {
-    new ShuffleMemoryManager(maxMemory, pageSizeBytes)
+    val conf = new SparkConf
+    val memoryManager = new StaticMemoryManager(
+      conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue)
+    new ShuffleMemoryManager(memoryManager, pageSizeBytes)
   }
 
   @VisibleForTesting
   def createForTesting(maxMemory: Long): ShuffleMemoryManager = {
-    new ShuffleMemoryManager(maxMemory, 4 * 1024 * 1024)
-  }
-
-  /**
-   * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
-   * of the memory pool and a safety factor since collections can sometimes grow bigger than
-   * the size we target before we estimate their sizes again.
-   */
-  private def getMaxMemory(conf: SparkConf): Long = {
-    val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
-    val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
-    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+    create(maxMemory, 4 * 1024 * 1024)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 47bd2ef..9f5bd2a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -31,6 +31,7 @@ import sun.nio.ch.DirectBuffer
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.memory.MemoryManager
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
@@ -64,8 +65,8 @@ private[spark] class BlockManager(
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
     defaultSerializer: Serializer,
-    maxMemory: Long,
     val conf: SparkConf,
+    memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
     shuffleManager: ShuffleManager,
     blockTransferService: BlockTransferService,
@@ -82,12 +83,15 @@ private[spark] class BlockManager(
 
   // Actual storage of where blocks are kept
   private var externalBlockStoreInitialized = false
-  private[spark] val memoryStore = new MemoryStore(this, maxMemory)
+  private[spark] val memoryStore = new MemoryStore(this, memoryManager)
   private[spark] val diskStore = new DiskStore(this, diskBlockManager)
   private[spark] lazy val externalBlockStore: ExternalBlockStore = {
     externalBlockStoreInitialized = true
     new ExternalBlockStore(this, executorId)
   }
+  memoryManager.setMemoryStore(memoryStore)
+
+  private val maxMemory = memoryManager.maxStorageMemory
 
   private[spark]
   val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
@@ -158,24 +162,6 @@ private[spark] class BlockManager(
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
 
   /**
-   * Construct a BlockManager with a memory limit set based on system properties.
-   */
-  def this(
-      execId: String,
-      rpcEnv: RpcEnv,
-      master: BlockManagerMaster,
-      serializer: Serializer,
-      conf: SparkConf,
-      mapOutputTracker: MapOutputTracker,
-      shuffleManager: ShuffleManager,
-      blockTransferService: BlockTransferService,
-      securityManager: SecurityManager,
-      numUsableCores: Int) = {
-    this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf),
-      conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores)
-  }
-
-  /**
    * Initializes the BlockManager with the given appId. This is not performed in the constructor as
    * the appId may not be known at BlockManager instantiation time (in particular for the driver,
    * where it is only learned after registration with the TaskScheduler).
@@ -1267,13 +1253,6 @@ private[spark] class BlockManager(
 private[spark] object BlockManager extends Logging {
   private val ID_GENERATOR = new IdGenerator
 
-  /** Return the total amount of storage memory available. */
-  private def getMaxMemory(conf: SparkConf): Long = {
-    val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6)
-    val safetyFraction = conf.getDouble("spark.storage.safetyFraction", 0.9)
-    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
-  }
-
   /**
    * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
    * might cause errors if one attempts to read from the unmapped buffer, but it's better than

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 6f27f00..35c57b9 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.TaskContext
+import org.apache.spark.memory.MemoryManager
 import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 
@@ -33,13 +34,12 @@ private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean)
  * Stores blocks in memory, either as Arrays of deserialized Java objects or as
  * serialized ByteBuffers.
  */
-private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
+private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: MemoryManager)
   extends BlockStore(blockManager) {
 
   private val conf = blockManager.conf
   private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true)
-
-  @volatile private var currentMemory = 0L
+  private val maxMemory = memoryManager.maxStorageMemory
 
   // Ensure only one thread is putting, and if necessary, dropping blocks at any given time
   private val accountingLock = new Object
@@ -56,15 +56,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   // memory (SPARK-4777).
   private val pendingUnrollMemoryMap = mutable.HashMap[Long, Long]()
 
-  /**
-   * The amount of space ensured for unrolling values in memory, shared across all cores.
-   * This space is not reserved in advance, but allocated dynamically by dropping existing blocks.
-   */
-  private val maxUnrollMemory: Long = {
-    val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2)
-    (maxMemory * unrollFraction).toLong
-  }
-
   // Initial memory to request before unrolling any block
   private val unrollMemoryThreshold: Long =
     conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
@@ -77,8 +68,14 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
 
   logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
 
-  /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */
-  def freeMemory: Long = maxMemory - currentMemory
+  /** Storage memory in use, in bytes, per the memory manager; includes unroll memory. */
+  private def memoryUsed: Long = memoryManager.storageMemoryUsed
+
+  /**
+   * Storage memory, in bytes, occupied by cached blocks only: the total from `memoryUsed`
+   * with the memory currently reserved for unrolling subtracted out.
+   */
+  private def blocksMemoryUsed: Long = this.memoryUsed - this.currentUnrollMemory
 
   override def getSize(blockId: BlockId): Long = {
     entries.synchronized {
@@ -94,8 +91,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       val values = blockManager.dataDeserialize(blockId, bytes)
       putIterator(blockId, values, level, returnValues = true)
     } else {
-      val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
-      PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
+      val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+      tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks)
+      PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks)
     }
   }
 
@@ -108,15 +106,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = {
     // Work on a duplicate - since the original input might be used elsewhere.
     lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer]
-    val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false)
+    // tryToPut appends any blocks it evicts to this buffer; they are reported via PutResult.
+    val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+    val putSuccess = tryToPut(blockId, () => bytes, size, deserialized = false, droppedBlocks)
+    // `data` is null when the block could not be stored in memory.
     val data =
-      if (putAttempt.success) {
+      if (putSuccess) {
         assert(bytes.limit == size)
         Right(bytes.duplicate())
       } else {
         null
       }
-    PutResult(size, data, putAttempt.droppedBlocks)
+    PutResult(size, data, droppedBlocks)
   }
 
   override def putArray(
@@ -124,14 +123,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       values: Array[Any],
       level: StorageLevel,
       returnValues: Boolean): PutResult = {
+    val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
     if (level.deserialized) {
       val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
-      val putAttempt = tryToPut(blockId, values, sizeEstimate, deserialized = true)
-      PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks)
+      tryToPut(blockId, values, sizeEstimate, deserialized = true, droppedBlocks)
+      PutResult(sizeEstimate, Left(values.iterator), droppedBlocks)
     } else {
       val bytes = blockManager.dataSerialize(blockId, values.iterator)
-      val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
-      PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
+      tryToPut(blockId, bytes, bytes.limit, deserialized = false, droppedBlocks)
+      PutResult(bytes.limit(), Right(bytes.duplicate()), droppedBlocks)
     }
   }
 
@@ -209,23 +209,22 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   }
 
   override def remove(blockId: BlockId): Boolean = {
-    entries.synchronized {
-      val entry = entries.remove(blockId)
-      if (entry != null) {
-        currentMemory -= entry.size
-        logDebug(s"Block $blockId of size ${entry.size} dropped from memory (free $freeMemory)")
-        true
-      } else {
-        false
-      }
+    // Only the map removal runs under the `entries` lock; the block's size is handed back to
+    // the memory manager afterwards.
+    val entry = entries.synchronized { entries.remove(blockId) }
+    if (entry != null) {
+      memoryManager.releaseStorageMemory(entry.size)
+      logDebug(s"Block $blockId of size ${entry.size} dropped " +
+        s"from memory (free ${maxMemory - blocksMemoryUsed})")
+      true
+    } else {
+      // No block with this ID was cached in memory.
+      false
     }
   }
 
   override def clear() {
     entries.synchronized {
       entries.clear()
-      currentMemory = 0
     }
+    memoryManager.releaseStorageMemory()
     logInfo("MemoryStore cleared")
   }
 
@@ -265,7 +264,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     var vector = new SizeTrackingVector[Any]
 
     // Request enough memory to begin unrolling
-    keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold)
+    keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, droppedBlocks)
 
     if (!keepUnrolling) {
       logWarning(s"Failed to reserve initial memory threshold of " +
@@ -281,20 +280,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
           val currentSize = vector.estimateSize()
           if (currentSize >= memoryThreshold) {
             val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
-            // Hold the accounting lock, in case another thread concurrently puts a block that
-            // takes up the unrolling space we just ensured here
-            accountingLock.synchronized {
-              if (!reserveUnrollMemoryForThisTask(amountToRequest)) {
-                // If the first request is not granted, try again after ensuring free space
-                // If there is still not enough space, give up and drop the partition
-                val spaceToEnsure = maxUnrollMemory - currentUnrollMemory
-                if (spaceToEnsure > 0) {
-                  val result = ensureFreeSpace(blockId, spaceToEnsure)
-                  droppedBlocks ++= result.droppedBlocks
-                }
-                keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest)
-              }
-            }
+            keepUnrolling = reserveUnrollMemoryForThisTask(
+              blockId, amountToRequest, droppedBlocks)
             // New threshold is currentSize * memoryGrowthFactor
             memoryThreshold += amountToRequest
           }
@@ -317,10 +304,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       // Otherwise, if we return an iterator, we release the memory reserved here
       // later when the task finishes.
       if (keepUnrolling) {
+        val taskAttemptId = currentTaskAttemptId()
         accountingLock.synchronized {
-          val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved
-          releaseUnrollMemoryForThisTask(amountToRelease)
-          reservePendingUnrollMemoryForThisTask(amountToRelease)
+          // Here, we transfer memory from unroll to pending unroll because we expect to cache this
+          // block in `tryToPut`. We do not release and re-acquire memory from the MemoryManager in
+          // order to avoid race conditions where another component steals the memory that we're
+          // trying to transfer.
+          val amountToTransferToPending = currentUnrollMemoryForThisTask - previousMemoryReserved
+          unrollMemoryMap(taskAttemptId) -= amountToTransferToPending
+          pendingUnrollMemoryMap(taskAttemptId) =
+            pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + amountToTransferToPending
         }
       }
     }
@@ -337,8 +330,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       blockId: BlockId,
       value: Any,
       size: Long,
-      deserialized: Boolean): ResultWithDroppedBlocks = {
-    tryToPut(blockId, () => value, size, deserialized)
+      deserialized: Boolean,
+      droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
+    tryToPut(blockId, () => value, size, deserialized, droppedBlocks)
   }
 
   /**
@@ -354,13 +348,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
    * blocks to free memory for one block, another thread may use up the freed space for
    * another block.
    *
-   * Return whether put was successful, along with the blocks dropped in the process.
+   * All blocks evicted in the process, if any, will be added to `droppedBlocks`.
+   *
+   * @return whether put was successful.
    */
   private def tryToPut(
       blockId: BlockId,
       value: () => Any,
       size: Long,
-      deserialized: Boolean): ResultWithDroppedBlocks = {
+      deserialized: Boolean,
+      droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
 
     /* TODO: Its possible to optimize the locking by locking entries only when selecting blocks
      * to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has
@@ -368,24 +365,27 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
      * for freeing up more space for another block that needs to be put. Only then the actually
      * dropping of blocks (and writing to disk if necessary) can proceed in parallel. */
 
-    var putSuccess = false
-    val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
-
     accountingLock.synchronized {
-      val freeSpaceResult = ensureFreeSpace(blockId, size)
-      val enoughFreeSpace = freeSpaceResult.success
-      droppedBlocks ++= freeSpaceResult.droppedBlocks
-
-      if (enoughFreeSpace) {
+      // Note: if we have previously unrolled this block successfully, then pending unroll
+      // memory should be non-zero. This is the amount that we already reserved during the
+      // unrolling process. In this case, we can just reuse this space to cache our block.
+      //
+      // Note: the StaticMemoryManager counts unroll memory as storage memory. Here, the
+      // synchronization on `accountingLock` guarantees that the release of unroll memory and
+      // acquisition of storage memory happens atomically. However, if storage memory is acquired
+      // outside of MemoryStore or if unroll memory is counted as execution memory, then we will
+      // have to revisit this assumption. See SPARK-10983 for more context.
+      releasePendingUnrollMemoryForThisTask()
+      val enoughMemory = memoryManager.acquireStorageMemory(blockId, size, droppedBlocks)
+      if (enoughMemory) {
+        // We acquired enough memory for the block, so go ahead and put it
         val entry = new MemoryEntry(value(), size, deserialized)
         entries.synchronized {
           entries.put(blockId, entry)
-          currentMemory += size
         }
         val valuesOrBytes = if (deserialized) "values" else "bytes"
+        // Report "free" as the capacity not occupied by cached blocks, matching remove().
         logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
-          blockId, valuesOrBytes, Utils.bytesToString(size), Utils.bytesToString(freeMemory)))
-        putSuccess = true
+          blockId, valuesOrBytes, Utils.bytesToString(size),
+          Utils.bytesToString(maxMemory - blocksMemoryUsed)))
       } else {
         // Tell the block manager that we couldn't put it in memory so that it can drop it to
         // disk if the block allows disk storage.
@@ -397,10 +397,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
         val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data)
         droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
       }
-      // Release the unroll memory used because we no longer need the underlying Array
-      releasePendingUnrollMemoryForThisTask()
+      enoughMemory
     }
-    ResultWithDroppedBlocks(putSuccess, droppedBlocks)
   }
 
   /**
@@ -409,40 +407,42 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
    * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
    * don't fit into memory that we want to avoid).
    *
-   * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping
-   * blocks. Otherwise, the freed space may fill up before the caller puts in their new value.
-   *
-   * Return whether there is enough free space, along with the blocks dropped in the process.
+   * @param blockId the ID of the block we are freeing space for
+   * @param space the size of this block
+   * @param droppedBlocks a holder for blocks evicted in the process
+   * @return whether there is enough free space.
    */
-  private def ensureFreeSpace(
-      blockIdToAdd: BlockId,
-      space: Long): ResultWithDroppedBlocks = {
-    logInfo(s"ensureFreeSpace($space) called with curMem=$currentMemory, maxMem=$maxMemory")
-
-    val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+  private[spark] def ensureFreeSpace(
+      blockId: BlockId,
+      space: Long,
+      droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
+    accountingLock.synchronized {
+      // memoryUsed includes memory reserved for unrolling, so in-flight unrolls reduce the
+      // space available here without being eligible for eviction below (only `entries` are).
+      val freeMemory = maxMemory - memoryUsed
+      val rddToAdd = getRddId(blockId)
+      val selectedBlocks = new ArrayBuffer[BlockId]
+      var selectedMemory = 0L
 
-    if (space > maxMemory) {
-      logInfo(s"Will not store $blockIdToAdd as it is larger than our memory limit")
-      return ResultWithDroppedBlocks(success = false, droppedBlocks)
-    }
+      logInfo(s"Ensuring $space bytes of free space for block $blockId " +
+        s"(free: $freeMemory, max: $maxMemory)")
 
-    // Take into account the amount of memory currently occupied by unrolling blocks
-    // and minus the pending unroll memory for that block on current thread.
-    val taskAttemptId = currentTaskAttemptId()
-    val actualFreeMemory = freeMemory - currentUnrollMemory +
-      pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L)
+      // Fail fast if the block simply won't fit
+      if (space > maxMemory) {
+        logInfo(s"Will not store $blockId as the required space " +
+          s"($space bytes) exceeds our memory limit ($maxMemory bytes)")
+        return false
+      }
 
-    if (actualFreeMemory < space) {
-      val rddToAdd = getRddId(blockIdToAdd)
-      val selectedBlocks = new ArrayBuffer[BlockId]
-      var selectedMemory = 0L
+      // No need to evict anything if there is already enough free space
+      if (freeMemory >= space) {
+        return true
+      }
 
       // This is synchronized to ensure that the set of entries is not changed
       // (because of getValue or getBytes) while traversing the iterator, as that
       // can lead to exceptions.
       entries.synchronized {
         val iterator = entries.entrySet().iterator()
-        while (actualFreeMemory + selectedMemory < space && iterator.hasNext) {
+        while (freeMemory + selectedMemory < space && iterator.hasNext) {
           val pair = iterator.next()
           val blockId = pair.getKey
           if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) {
@@ -452,7 +452,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
         }
       }
 
-      if (actualFreeMemory + selectedMemory >= space) {
+      if (freeMemory + selectedMemory >= space) {
         logInfo(s"${selectedBlocks.size} blocks selected for dropping")
         for (blockId <- selectedBlocks) {
           val entry = entries.synchronized { entries.get(blockId) }
@@ -469,14 +469,13 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
             droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
           }
         }
-        return ResultWithDroppedBlocks(success = true, droppedBlocks)
+        true
       } else {
-        logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " +
+        logInfo(s"Will not store $blockId as it would require dropping another block " +
           "from the same RDD")
-        return ResultWithDroppedBlocks(success = false, droppedBlocks)
+        false
       }
     }
-    ResultWithDroppedBlocks(success = true, droppedBlocks)
   }
 
   override def contains(blockId: BlockId): Boolean = {
@@ -489,17 +488,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   }
 
   /**
-   * Reserve additional memory for unrolling blocks used by this task.
-   * Return whether the request is granted.
+   * Reserve memory for unrolling the given block for this task.
+   *
+   * The request is forwarded to the memory manager; if it is granted, the amount is also
+   * recorded against the current task attempt in `unrollMemoryMap`.
+   *
+   * @param blockId the ID of the block being unrolled
+   * @param memory the number of bytes to reserve
+   * @param droppedBlocks passed through to the memory manager, which may record blocks evicted
+   *                      to make room for this reservation
+   * @return whether the request is granted.
    */
-  def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
+  def reserveUnrollMemoryForThisTask(
+      blockId: BlockId,
+      memory: Long,
+      droppedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = {
     accountingLock.synchronized {
-      val granted = freeMemory > currentUnrollMemory + memory
-      if (granted) {
+      // Note: all acquisitions of unroll memory must be synchronized on `accountingLock`
+      val success = memoryManager.acquireUnrollMemory(blockId, memory, droppedBlocks)
+      if (success) {
+        // Track the grant per task attempt so it can later be released, or moved to
+        // `pendingUnrollMemoryMap` once unrolling succeeds.
         val taskAttemptId = currentTaskAttemptId()
         unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
       }
-      granted
+      success
     }
   }
 
@@ -507,40 +510,38 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
    * Release memory used by this task for unrolling blocks.
    * If the amount is not specified, remove the current task's allocation altogether.
    */
-  def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = {
+  def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = {
     val taskAttemptId = currentTaskAttemptId()
     accountingLock.synchronized {
-      if (memory < 0) {
-        unrollMemoryMap.remove(taskAttemptId)
-      } else {
-        unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory
-        // If this task claims no more unroll memory, release it completely
-        if (unrollMemoryMap(taskAttemptId) <= 0) {
-          unrollMemoryMap.remove(taskAttemptId)
+      if (unrollMemoryMap.contains(taskAttemptId)) {
+        // Never release more than this task holds; the default of Long.MaxValue therefore
+        // releases the task's entire reservation. Note that a non-positive `memory` releases
+        // nothing (the former `-1` "release everything" convention no longer applies).
+        val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId))
+        if (memoryToRelease > 0) {
+          unrollMemoryMap(taskAttemptId) -= memoryToRelease
+          // Drop the entry at zero so the task is no longer counted by numTasksUnrolling.
+          if (unrollMemoryMap(taskAttemptId) == 0) {
+            unrollMemoryMap.remove(taskAttemptId)
+          }
+          memoryManager.releaseUnrollMemory(memoryToRelease)
         }
       }
     }
   }
 
   /**
-   * Reserve the unroll memory of current unroll successful block used by this task
-   * until actually put the block into memory entry.
-   */
-  def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = {
-    val taskAttemptId = currentTaskAttemptId()
-    accountingLock.synchronized {
-       pendingUnrollMemoryMap(taskAttemptId) =
-         pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
-    }
-  }
-
-  /**
    * Release pending unroll memory of current unroll successful block used by this task
    */
-  def releasePendingUnrollMemoryForThisTask(): Unit = {
+  def releasePendingUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = {
     val taskAttemptId = currentTaskAttemptId()
     accountingLock.synchronized {
-      pendingUnrollMemoryMap.remove(taskAttemptId)
+      if (pendingUnrollMemoryMap.contains(taskAttemptId)) {
+        // Clamp to what this task has pending; the default of Long.MaxValue releases all of it.
+        val memoryToRelease = math.min(memory, pendingUnrollMemoryMap(taskAttemptId))
+        if (memoryToRelease > 0) {
+          pendingUnrollMemoryMap(taskAttemptId) -= memoryToRelease
+          if (pendingUnrollMemoryMap(taskAttemptId) == 0) {
+            pendingUnrollMemoryMap.remove(taskAttemptId)
+          }
+          // Pending memory was acquired as unroll memory, so it is returned as unroll memory.
+          memoryManager.releaseUnrollMemory(memoryToRelease)
+        }
+      }
     }
   }
 
@@ -561,19 +562,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   /**
    * Return the number of tasks currently unrolling blocks.
    */
-  def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
+  private def numTasksUnrolling: Int = accountingLock.synchronized {
+    // unrollMemoryMap is keyed by task attempt ID, so its size is the number of tasks.
+    unrollMemoryMap.size
+  }
 
   /**
    * Log information about current memory usage.
    */
-  def logMemoryUsage(): Unit = {
-    val blocksMemory = currentMemory
-    val unrollMemory = currentUnrollMemory
-    val totalMemory = blocksMemory + unrollMemory
+  private def logMemoryUsage(): Unit = {
+    // Read each counter once and derive the block figure from them, so the logged components
+    // always add up to the logged total even if memory usage changes while we build the message.
+    val totalMemory = memoryUsed
+    val unrollMemory = currentUnrollMemory
+    val blocksMemory = totalMemory - unrollMemory
     logInfo(
-      s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
-      s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
-      s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " +
+      s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
+      s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
+      s"$numTasksUnrolling task(s)) = ${Utils.bytesToString(totalMemory)}. " +
       s"Storage limit = ${Utils.bytesToString(maxMemory)}."
     )
   }
@@ -584,7 +582,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
    * @param blockId ID of the block we are trying to unroll.
    * @param finalVectorSize Final size of the vector before unrolling failed.
    */
-  def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
+  private def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
     logWarning(
       s"Not enough space to cache $blockId in memory! " +
       s"(computed ${Utils.bytesToString(finalVectorSize)} so far)"
@@ -592,7 +590,3 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     logMemoryUsage()
   }
 }
-
-private[spark] case class ResultWithDroppedBlocks(
-    success: Boolean,
-    droppedBlocks: Seq[(BlockId, BlockStatus)])

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
new file mode 100644
index 0000000..c436a8b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.Mockito.{mock, reset, verify, when}
+import org.mockito.Matchers.{any, eq => meq}
+
+import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId}
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+
+class StaticMemoryManagerSuite extends SparkFunSuite {
+  private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4")
+
+  test("basic execution memory") {
+    val maxExecutionMem = 1000L
+    val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue)
+    assert(mm.executionMemoryUsed === 0L)
+    assert(mm.acquireExecutionMemory(10L) === 10L)
+    assert(mm.executionMemoryUsed === 10L)
+    assert(mm.acquireExecutionMemory(100L) === 100L)
+    // Request more than remains: only the remaining 890 bytes are granted
+    assert(mm.acquireExecutionMemory(1000L) === 890L)
+    assert(mm.executionMemoryUsed === maxExecutionMem)
+    assert(mm.acquireExecutionMemory(1L) === 0L)
+    assert(mm.executionMemoryUsed === maxExecutionMem)
+    mm.releaseExecutionMemory(800L)
+    assert(mm.executionMemoryUsed === 200L)
+    // Acquire after release
+    assert(mm.acquireExecutionMemory(1L) === 1L)
+    assert(mm.executionMemoryUsed === 201L)
+    // Release beyond what was acquired
+    mm.releaseExecutionMemory(maxExecutionMem)
+    assert(mm.executionMemoryUsed === 0L)
+  }
+
+  test("basic storage memory") {
+    val maxStorageMem = 1000L
+    val dummyBlock = TestBlockId("you can see the world you brought to live")
+    val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+    val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem)
+    assert(mm.storageMemoryUsed === 0L)
+    assert(mm.acquireStorageMemory(dummyBlock, 10L, evictedBlocks))
+    // `ensureFreeSpace` should be called with the number of bytes requested
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 10L)
+    assert(mm.storageMemoryUsed === 10L)
+    assert(evictedBlocks.isEmpty)
+    assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L)
+    assert(mm.storageMemoryUsed === 110L)
+    // Unlike execution memory, a request larger than the remaining 890 bytes is refused
+    assert(!mm.acquireStorageMemory(dummyBlock, 1000L, evictedBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 1000L)
+    assert(mm.storageMemoryUsed === 110L)
+    assert(mm.acquireStorageMemory(dummyBlock, 890L, evictedBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 890L)
+    assert(mm.storageMemoryUsed === 1000L)
+    assert(!mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L)
+    assert(mm.storageMemoryUsed === 1000L)
+    mm.releaseStorageMemory(800L)
+    assert(mm.storageMemoryUsed === 200L)
+    // Acquire after release
+    assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L)
+    assert(mm.storageMemoryUsed === 201L)
+    mm.releaseStorageMemory()
+    assert(mm.storageMemoryUsed === 0L)
+    assert(mm.acquireStorageMemory(dummyBlock, 1L, evictedBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 1L)
+    assert(mm.storageMemoryUsed === 1L)
+    // Release beyond what was acquired
+    mm.releaseStorageMemory(100L)
+    assert(mm.storageMemoryUsed === 0L)
+  }
+
+  test("execution and storage isolation") {
+    val maxExecutionMem = 200L
+    val maxStorageMem = 1000L
+    val dummyBlock = TestBlockId("ain't nobody love like you do")
+    val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+    val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem)
+    // Only execution memory should increase
+    assert(mm.acquireExecutionMemory(100L) === 100L)
+    assert(mm.storageMemoryUsed === 0L)
+    assert(mm.executionMemoryUsed === 100L)
+    assert(mm.acquireExecutionMemory(1000L) === 100L)
+    assert(mm.storageMemoryUsed === 0L)
+    assert(mm.executionMemoryUsed === 200L)
+    // Only storage memory should increase
+    assert(mm.acquireStorageMemory(dummyBlock, 50L, dummyBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 50L)
+    assert(mm.storageMemoryUsed === 50L)
+    assert(mm.executionMemoryUsed === 200L)
+    // Only execution memory should be released
+    mm.releaseExecutionMemory(133L)
+    assert(mm.storageMemoryUsed === 50L)
+    assert(mm.executionMemoryUsed === 67L)
+    // Only storage memory should be released
+    mm.releaseStorageMemory()
+    assert(mm.storageMemoryUsed === 0L)
+    assert(mm.executionMemoryUsed === 67L)
+  }
+
+  test("unroll memory") {
+    val maxStorageMem = 1000L
+    val dummyBlock = TestBlockId("lonely water")
+    val dummyBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+    val (mm, ms) = makeThings(Long.MaxValue, maxStorageMem)
+    assert(mm.acquireUnrollMemory(dummyBlock, 100L, dummyBlocks))
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 100L)
+    assert(mm.storageMemoryUsed === 100L)
+    mm.releaseUnrollMemory(40L)
+    assert(mm.storageMemoryUsed === 60L)
+    when(ms.currentUnrollMemory).thenReturn(60L)
+    assert(mm.acquireUnrollMemory(dummyBlock, 500L, dummyBlocks))
+    // `spark.storage.unrollFraction` is 0.4, so the max unroll space is 400 bytes.
+    // Since we already occupy 60 bytes, we will try to ensure only 400 - 60 = 340 bytes.
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 340L)
+    assert(mm.storageMemoryUsed === 560L)
+    when(ms.currentUnrollMemory).thenReturn(560L)
+    assert(!mm.acquireUnrollMemory(dummyBlock, 800L, dummyBlocks))
+    assert(mm.storageMemoryUsed === 560L)
+    // We already have 560 bytes > the max unroll space of 400 bytes, so no bytes are freed
+    assertEnsureFreeSpaceCalled(ms, dummyBlock, 0L)
+    // Release beyond what was acquired
+    mm.releaseUnrollMemory(maxStorageMem)
+    assert(mm.storageMemoryUsed === 0L)
+  }
+
+  /**
+   * Make a [[StaticMemoryManager]] backed by a Mockito-mocked [[MemoryStore]].
+   */
+  private def makeThings(
+      maxExecutionMem: Long,
+      maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = {
+    val mm = new StaticMemoryManager(
+      conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem)
+    val ms = mock(classOf[MemoryStore])
+    mm.setMemoryStore(ms)
+    (mm, ms)
+  }
+
+  /**
+   * Verify [[MemoryStore.ensureFreeSpace]] was called with these arguments, then reset the mock.
+   */
+  private def assertEnsureFreeSpaceCalled(
+      ms: MemoryStore,
+      blockId: BlockId,
+      numBytes: Long): Unit = {
+    verify(ms).ensureFreeSpace(meq(blockId), meq(numBytes: java.lang.Long), any())
+    reset(ms)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index eb5af70..cc44c67 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark._
+import org.apache.spark.memory.StaticMemoryManager
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.KryoSerializer
@@ -39,29 +40,31 @@ import org.apache.spark.storage.StorageLevel._
 class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter {
 
   private val conf = new SparkConf(false).set("spark.app.id", "test")
-  var rpcEnv: RpcEnv = null
-  var master: BlockManagerMaster = null
-  val securityMgr = new SecurityManager(conf)
-  val mapOutputTracker = new MapOutputTrackerMaster(conf)
-  val shuffleManager = new HashShuffleManager(conf)
+  private var rpcEnv: RpcEnv = null
+  private var master: BlockManagerMaster = null
+  private val securityMgr = new SecurityManager(conf)
+  private val mapOutputTracker = new MapOutputTrackerMaster(conf)
+  private val shuffleManager = new HashShuffleManager(conf)
 
   // List of block manager created during an unit test, so that all of the them can be stopped
   // after the unit test.
-  val allStores = new ArrayBuffer[BlockManager]
+  private val allStores = new ArrayBuffer[BlockManager]
 
   // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
   conf.set("spark.kryoserializer.buffer", "1m")
-  val serializer = new KryoSerializer(conf)
+  private val serializer = new KryoSerializer(conf)
 
   // Implicitly convert strings to BlockIds for test clarity.
-  implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+  private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
 
   private def makeBlockManager(
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf,
-      mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val store = new BlockManager(name, rpcEnv, master, serializer, conf,
+      memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+    memManager.setMemoryStore(store.memoryStore)
     store.initialize("app-id")
     allStores += store
     store
@@ -258,8 +261,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
     val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work
     when(failableTransfer.hostName).thenReturn("some-hostname")
     when(failableTransfer.port).thenReturn(1000)
-    val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer,
-      10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000)
+    val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf,
+      memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
+    memManager.setMemoryStore(failableStore.memoryStore)
     failableStore.initialize("app-id")
     allStores += failableStore // so that this gets stopped after test
     assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId))

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 34bb495..f3fab33 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark._
 import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.memory.StaticMemoryManager
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -67,10 +68,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       maxMem: Long,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf,
-      mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
-    manager.initialize("app-id")
-    manager
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+    val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
+      memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+    memManager.setMemoryStore(blockManager.memoryStore)
+    blockManager.initialize("app-id")
+    blockManager
   }
 
   override def beforeEach(): Unit = {
@@ -820,9 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   test("block store put failure") {
     // Use Java serializer so we can create an unserializable error.
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
+    val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200)
     store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
-      new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr,
-      0)
+      new JavaSerializer(conf), conf, memoryManager, mapOutputTracker,
+      shuffleManager, transfer, securityMgr, 0)
+    memoryManager.setMemoryStore(store.memoryStore)
 
     // The put should fail since a1 is not serializable.
     class UnserializableClass
@@ -1043,14 +1048,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(memoryStore.currentUnrollMemory === 0)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
+    def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
+      memoryStore.reserveUnrollMemoryForThisTask(
+        TestBlockId(""), memory, new ArrayBuffer[(BlockId, BlockStatus)])
+    }
+
     // Reserve
-    memoryStore.reserveUnrollMemoryForThisTask(100)
+    assert(reserveUnrollMemoryForThisTask(100))
     assert(memoryStore.currentUnrollMemoryForThisTask === 100)
-    memoryStore.reserveUnrollMemoryForThisTask(200)
+    assert(reserveUnrollMemoryForThisTask(200))
     assert(memoryStore.currentUnrollMemoryForThisTask === 300)
-    memoryStore.reserveUnrollMemoryForThisTask(500)
+    assert(reserveUnrollMemoryForThisTask(500))
     assert(memoryStore.currentUnrollMemoryForThisTask === 800)
-    memoryStore.reserveUnrollMemoryForThisTask(1000000)
+    assert(!reserveUnrollMemoryForThisTask(1000000))
     assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted
     // Release
     memoryStore.releaseUnrollMemoryForThisTask(100)
@@ -1058,9 +1068,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     memoryStore.releaseUnrollMemoryForThisTask(100)
     assert(memoryStore.currentUnrollMemoryForThisTask === 600)
     // Reserve again
-    memoryStore.reserveUnrollMemoryForThisTask(4400)
+    assert(reserveUnrollMemoryForThisTask(4400))
     assert(memoryStore.currentUnrollMemoryForThisTask === 5000)
-    memoryStore.reserveUnrollMemoryForThisTask(20000)
+    assert(!reserveUnrollMemoryForThisTask(20000))
     assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted
     // Release again
     memoryStore.releaseUnrollMemoryForThisTask(1000)

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
index 48c3938..ff65d7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
@@ -17,12 +17,18 @@
 
 package org.apache.spark.sql.execution
 
+import scala.collection.mutable
+
+import org.apache.spark.memory.MemoryManager
 import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.storage.{BlockId, BlockStatus}
+
 
 /**
  * A [[ShuffleMemoryManager]] that can be controlled to run out of memory.
  */
-class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1024 * 1024) {
+class TestShuffleMemoryManager
+  extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) {
   private var oom = false
 
   override def tryToAcquire(numBytes: Long): Long = {
@@ -49,3 +55,23 @@ class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue, 4 * 1
     oom = true
   }
 }
+
+/** Test-only [[MemoryManager]] that approves every request and keeps no accounting. */
+private class GrantEverythingMemoryManager extends MemoryManager {
+  // All limits/usage report Long.MaxValue. NOTE(review): used == max reads as full; confirm.
+  override def maxExecutionMemory: Long = Long.MaxValue
+  override def maxStorageMemory: Long = Long.MaxValue
+  override def executionMemoryUsed: Long = Long.MaxValue
+  override def storageMemoryUsed: Long = Long.MaxValue
+  // Every acquire succeeds in full; evictedBlocks is never touched.
+  override def acquireExecutionMemory(numBytes: Long): Long = numBytes
+  override def acquireStorageMemory(blockId: BlockId, numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def acquireUnrollMemory(blockId: BlockId, numBytes: Long,
+      evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  // Nothing is recorded on acquire, so every release is a no-op.
+  override def releaseExecutionMemory(numBytes: Long): Unit = {}
+  override def releaseStorageMemory(numBytes: Long): Unit = {}
+  override def releaseStorageMemory(): Unit = {}
+  override def releaseUnrollMemory(numBytes: Long): Unit = {}
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/67fbecbf/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 13cfe29..b2b6848 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark._
+import org.apache.spark.memory.StaticMemoryManager
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
@@ -253,12 +254,14 @@ class ReceivedBlockHandlerSuite
       maxMem: Long,
       conf: SparkConf,
       name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
+    val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
     val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
-    val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf,
-      mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
-    manager.initialize("app-id")
-    blockManagerBuffer += manager
-    manager
+    val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf,
+      memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
+    memManager.setMemoryStore(blockManager.memoryStore)
+    blockManager.initialize("app-id")
+    blockManagerBuffer += blockManager
+    blockManager
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org