Posted to commits@spark.apache.org by pw...@apache.org on 2014/04/08 08:41:00 UTC
[2/3] [SPARK-1103] Automatic garbage collection of RDD,
shuffle and broadcast data
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 19138d9..b021564 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -19,20 +19,22 @@ package org.apache.spark.storage
import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
+
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Random
+
import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+
+import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
-
sealed trait Values
case class ByteBufferValues(buffer: ByteBuffer) extends Values
@@ -46,7 +48,8 @@ private[spark] class BlockManager(
val defaultSerializer: Serializer,
maxMemory: Long,
val conf: SparkConf,
- securityManager: SecurityManager)
+ securityManager: SecurityManager,
+ mapOutputTracker: MapOutputTracker)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
@@ -55,7 +58,7 @@ private[spark] class BlockManager(
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
- private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private[storage] val memoryStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
var tachyonInitialized = false
private[storage] lazy val tachyonStore: TachyonStore = {
@@ -98,7 +101,7 @@ private[spark] class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
// Pending re-registration action being executed asynchronously or null if none
@@ -137,9 +140,10 @@ private[spark] class BlockManager(
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
- securityManager: SecurityManager) = {
+ securityManager: SecurityManager,
+ mapOutputTracker: MapOutputTracker) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, securityManager)
+ conf, securityManager, mapOutputTracker)
}
/**
@@ -217,9 +221,26 @@ private[spark] class BlockManager(
}
/**
- * Get storage level of local block. If no info exists for the block, then returns null.
+ * Get the BlockStatus for the block identified by the given ID, if it exists.
+ * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
+ */
+ def getStatus(blockId: BlockId): Option[BlockStatus] = {
+ blockInfo.get(blockId).map { info =>
+ val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
+ val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L
+ // Assume that block is not in Tachyon
+ BlockStatus(info.level, memSize, diskSize, 0L)
+ }
+ }
+
+ /**
+ * Get the ids of existing blocks that match the given filter. Note that this will
+ * query the blocks stored in the disk block manager (which the block manager
+ * may not be aware of).
*/
- def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = {
+ (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq
+ }
/**
* Tell the master about the current storage status of a block. This will send a block update
@@ -525,9 +546,8 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
- * The Block will be appended to the File specified by filename.
- * This is currently used for writing shuffle files out. Callers should handle error
- * cases.
+ * The Block will be appended to the File specified by filename. This is currently used for
+ * writing shuffle files out. Callers should handle error cases.
*/
def getDiskWriter(
blockId: BlockId,
@@ -863,11 +883,22 @@ private[spark] class BlockManager(
* @return The number of blocks removed.
*/
def removeRdd(rddId: Int): Int = {
- // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
- // from RDD.id to blocks.
+ // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
logInfo("Removing RDD " + rddId)
val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
- blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
+ blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
+ blocksToRemove.size
+ }
+
+ /**
+ * Remove all blocks belonging to the given broadcast.
+ */
+ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
+ logInfo("Removing broadcast " + broadcastId)
+ val blocksToRemove = blockInfo.keys.collect {
+ case bid @ BroadcastBlockId(`broadcastId`, _) => bid
+ }
+ blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
blocksToRemove.size
}
@@ -908,10 +939,10 @@ private[spark] class BlockManager(
}
private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
- val iterator = blockInfo.internalMap.entrySet().iterator()
+ val iterator = blockInfo.getEntrySet.iterator
while (iterator.hasNext) {
val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
@@ -935,7 +966,7 @@ private[spark] class BlockManager(
def shouldCompress(blockId: BlockId): Boolean = blockId match {
case ShuffleBlockId(_, _, _) => compressShuffle
- case BroadcastBlockId(_) => compressBroadcast
+ case BroadcastBlockId(_, _) => compressBroadcast
case RDDBlockId(_, _) => compressRdds
case TempBlockId(_) => compressShuffleSpill
case _ => false
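
A note on the pattern used in removeBroadcast above: the backquotes around
`broadcastId` in the collect make it a stable-identifier pattern, so the case
matches only BroadcastBlockIds whose first field equals the method parameter,
rather than binding a fresh variable that would match everything. A minimal
standalone sketch of the idiom (BroadcastBlockId is simplified here, not
Spark's actual class):

    object BroadcastMatchSketch {
      // Simplified stand-in for org.apache.spark.storage.BroadcastBlockId
      case class BroadcastBlockId(broadcastId: Long, field: String = "")

      def blocksFor(targetId: Long, ids: Seq[BroadcastBlockId]): Seq[BroadcastBlockId] = {
        // Backquotes compare against the existing targetId value; without them,
        // `targetId` would be a new binding that matches every block id
        ids.collect { case bid @ BroadcastBlockId(`targetId`, _) => bid }
      }

      def main(args: Array[String]) {
        val ids = Seq(BroadcastBlockId(1), BroadcastBlockId(1, "piece0"), BroadcastBlockId(2))
        println(blocksFor(1L, ids)) // keeps only the two blocks of broadcast 1
      }
    }
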
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 4bc1b40..7897fad 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -81,6 +81,14 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
+ /**
+ * Check if block manager master has a block. Note that this can only check
+ * blocks that have been reported to the block manager master.
+ */
+ def contains(blockId: BlockId) = {
+ !getLocations(blockId).isEmpty
+ }
+
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
@@ -99,12 +107,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply(RemoveBlock(blockId))
}
- /**
- * Remove all blocks belonging to the given RDD.
- */
+ /** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
- future onFailure {
+ future.onFailure {
case e: Throwable => logError("Failed to remove RDD " + rddId, e)
}
if (blocking) {
@@ -112,6 +118,31 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
}
}
+ /** Remove all blocks belonging to the given shuffle. */
+ def removeShuffle(shuffleId: Int, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+ future.onFailure {
+ case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
+ /** Remove all blocks belonging to the given broadcast. */
+ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Int]]](
+ RemoveBroadcast(broadcastId, removeFromMaster))
+ future.onFailure {
+ case e: Throwable =>
+ logError("Failed to remove broadcast " + broadcastId +
+ " with removeFromMaster = " + removeFromMaster, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
/**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
@@ -126,6 +157,51 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
}
+ /**
+ * Return the block's status on all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, this invokes the master to query each block manager for the most
+ * updated block statuses. This is useful when the master is not informed of the given block
+ * by all block managers.
+ */
+ def getBlockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
+ val msg = GetBlockStatus(blockId, askSlaves)
+ /*
+ * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+ * should not block on waiting for a block manager, which can in turn be waiting for the
+ * master actor for a response to a prior message.
+ */
+ val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+ val (blockManagerIds, futures) = response.unzip
+ val result = Await.result(Future.sequence(futures), timeout)
+ if (result == null) {
+ throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
+ }
+ val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]]
+ blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) =>
+ status.map { s => (blockManagerId, s) }
+ }.toMap
+ }
+
+ /**
+ * Return the ids of existing blocks that match the given filter. NOTE: This
+ * is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, this invokes the master to query each block manager for the most
+ * updated block statuses. This is useful when the master is not informed of the given block
+ * by all block managers.
+ */
+ def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Seq[BlockId] = {
+ val msg = GetMatchingBlockIds(filter, askSlaves)
+ val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
+ Await.result(future, timeout)
+ }
+
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {
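
The comment inside getBlockStatus above carries the key design decision: the
master must never block while a block manager might itself be waiting on the
master, so each per-manager query returns a Future and only the combined
result is awaited once. A self-contained sketch of that shape using plain
scala.concurrent (queryManager and the ids are illustrative, not Spark APIs):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    object FutureSequenceSketch {
      // Stand-in for asking one block manager for a block's status
      def queryManager(id: Int): Future[Option[String]] = Future {
        if (id % 2 == 0) Some("status-" + id) else None
      }

      def main(args: Array[String]) {
        val managerIds = Seq(1, 2, 3, 4)
        val futures = managerIds.map(queryManager)
        // Combine into one future and block exactly once, with a timeout
        val results = Await.result(Future.sequence(futures), 10.seconds)
        val statuses = managerIds.zip(results).collect { case (id, Some(s)) => (id, s) }
        println(statuses.toMap) // Map(2 -> status-2, 4 -> status-4)
      }
    }
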
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 378f4ca..c57b6e8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -94,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetStorageStatus =>
sender ! storageStatus
+ case GetBlockStatus(blockId, askSlaves) =>
+ sender ! blockStatus(blockId, askSlaves)
+
+ case GetMatchingBlockIds(filter, askSlaves) =>
+ sender ! getMatchingBlockIds(filter, askSlaves)
+
case RemoveRdd(rddId) =>
sender ! removeRdd(rddId)
+ case RemoveShuffle(shuffleId) =>
+ sender ! removeShuffle(shuffleId)
+
+ case RemoveBroadcast(broadcastId, removeFromDriver) =>
+ sender ! removeBroadcast(broadcastId, removeFromDriver)
+
case RemoveBlock(blockId) =>
removeBlockFromWorkers(blockId)
sender ! true
@@ -140,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
// The dispatcher is used as an implicit argument into the Future sequence construction.
import context.dispatcher
val removeMsg = RemoveRdd(rddId)
- Future.sequence(blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
- }.toSeq)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq
+ )
+ }
+
+ private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
+ // Nothing to do in the BlockManagerMasterActor data structures
+ import context.dispatcher
+ val removeMsg = RemoveShuffle(shuffleId)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+ }.toSeq
+ )
+ }
+
+ /**
+ * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
+ * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
+ * from the executors, but not from the driver.
+ */
+ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
+ // TODO: Consolidate usages of <driver>
+ import context.dispatcher
+ val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
+ val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+ removeFromDriver || info.blockManagerId.executorId != "<driver>"
+ }
+ Future.sequence(
+ requiredBlockManagers.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq
+ )
}
private def removeBlockManager(blockManagerId: BlockManagerId) {
@@ -225,6 +269,61 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}.toArray
}
+ /**
+ * Return the block's status for all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def blockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
+ import context.dispatcher
+ val getBlockStatus = GetBlockStatus(blockId)
+ /*
+ * Rather than blocking on the block status query, the master actor should simply return
+ * Futures to avoid potential deadlocks. This can arise if there exists a block manager
+ * that is also waiting for this master actor's response to a previous message.
+ */
+ blockManagerInfo.values.map { info =>
+ val blockStatusFuture =
+ if (askSlaves) {
+ info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
+ } else {
+ Future { info.getStatus(blockId) }
+ }
+ (info.blockManagerId, blockStatusFuture)
+ }.toMap
+ }
+
+ /**
+ * Return the ids of blocks, across all the block managers, that match the given filter.
+ * NOTE: This is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Future[Seq[BlockId]] = {
+ import context.dispatcher
+ val getMatchingBlockIds = GetMatchingBlockIds(filter)
+ Future.sequence(
+ blockManagerInfo.values.map { info =>
+ val future =
+ if (askSlaves) {
+ info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
+ } else {
+ Future { info.blocks.keys.filter(filter).toSeq }
+ }
+ future
+ }
+ ).map(_.flatten.toSeq)
+ }
+
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -334,6 +433,8 @@ private[spark] class BlockManagerInfo(
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
+ def getStatus(blockId: BlockId) = Option(_blocks.get(blockId))
+
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 8a36b5c..2b53bf3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages {
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
+ // Remove all blocks belonging to a specific shuffle.
+ case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave
+
+ // Remove all blocks belonging to a specific broadcast.
+ case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
+ extends ToBlockManagerSlave
+
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
@@ -80,7 +87,8 @@ private[storage] object BlockManagerMessages {
}
object UpdateBlockInfo {
- def apply(blockManagerId: BlockManagerId,
+ def apply(
+ blockManagerId: BlockManagerId,
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
@@ -108,7 +116,13 @@ private[storage] object BlockManagerMessages {
case object GetMemoryStatus extends ToBlockManagerMaster
- case object ExpireDeadHosts extends ToBlockManagerMaster
-
case object GetStorageStatus extends ToBlockManagerMaster
+
+ case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true)
+ extends ToBlockManagerMaster
+
+ case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true)
+ extends ToBlockManagerMaster
+
+ case object ExpireDeadHosts extends ToBlockManagerMaster
}
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index bcfb82d..6d4db06 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -17,8 +17,11 @@
package org.apache.spark.storage
-import akka.actor.Actor
+import scala.concurrent.Future
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark.{Logging, MapOutputTracker}
import org.apache.spark.storage.BlockManagerMessages._
/**
@@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
-class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
- override def receive = {
+class BlockManagerSlaveActor(
+ blockManager: BlockManager,
+ mapOutputTracker: MapOutputTracker)
+ extends Actor with Logging {
+
+ import context.dispatcher
+ // Operations that involve removing blocks may be slow and should be done asynchronously
+ override def receive = {
case RemoveBlock(blockId) =>
- blockManager.removeBlock(blockId)
+ doAsync[Boolean]("removing block " + blockId, sender) {
+ blockManager.removeBlock(blockId)
+ true
+ }
case RemoveRdd(rddId) =>
- val numBlocksRemoved = blockManager.removeRdd(rddId)
- sender ! numBlocksRemoved
+ doAsync[Int]("removing RDD " + rddId, sender) {
+ blockManager.removeRdd(rddId)
+ }
+
+ case RemoveShuffle(shuffleId) =>
+ doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
+ if (mapOutputTracker != null) {
+ mapOutputTracker.unregisterShuffle(shuffleId)
+ }
+ blockManager.shuffleBlockManager.removeShuffle(shuffleId)
+ }
+
+ case RemoveBroadcast(broadcastId, tellMaster) =>
+ doAsync[Int]("removing broadcast " + broadcastId, sender) {
+ blockManager.removeBroadcast(broadcastId, tellMaster)
+ }
+
+ case GetBlockStatus(blockId, _) =>
+ sender ! blockManager.getStatus(blockId)
+
+ case GetMatchingBlockIds(filter, _) =>
+ sender ! blockManager.getMatchingBlockIds(filter)
+ }
+
+ private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+ val future = Future {
+ logDebug(actionMessage)
+ body
+ }
+ future.onSuccess { case response =>
+ logDebug("Done " + actionMessage + ", response is " + response)
+ responseActor ! response
+ logDebug("Sent response: " + response + " to " + responseActor)
+ }
+ future.onFailure { case t: Throwable =>
+ logError("Error in " + actionMessage, t)
+ responseActor ! null.asInstanceOf[T]
+ }
}
}
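
The doAsync helper above takes the sender as an explicit parameter because an
actor's sender reference is only valid while the current message is being
processed; by the time the Future's callback runs, it may point elsewhere. A
stripped-down sketch of the same capture-then-reply pattern (assuming the
Akka 2.x API this code targets; SlowWorker is illustrative):

    import scala.concurrent.Future
    import akka.actor.{Actor, ActorRef, ActorSystem, Props}

    class SlowWorker extends Actor {
      import context.dispatcher

      def receive = {
        case n: Int =>
          // Capture the sender now; inside the Future it may have changed
          val responseActor: ActorRef = sender
          Future {
            Thread.sleep(100) // stand-in for slow block removal
            n * 2
          }.onSuccess { case result => responseActor ! result }
      }
    }

    object DoAsyncSketch {
      def main(args: Array[String]) {
        val system = ActorSystem("sketch")
        val worker = system.actorOf(Props(new SlowWorker))
        worker ! 21 // fire-and-forget here; in Spark the reply goes to the master
        Thread.sleep(500)
        system.shutdown()
      }
    }
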
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f3e1c38..7a24c8f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ /** Check if disk block manager has a block. */
+ def containsBlock(blockId: BlockId): Boolean = {
+ getBlockLocation(blockId).file.exists()
+ }
+
+ /** List all the blocks currently stored on disk by the disk manager. */
+ def getAllBlocks(): Seq[BlockId] = {
+ // Get all the files inside the array of arrays of directories
+ subDirs.flatten.filter(_ != null).flatMap { dir =>
+ val files = dir.list()
+ if (files != null) files else Seq.empty
+ }.map(BlockId.apply)
+ }
+
/** Produces a unique block id and File suitable for intermediate results. */
def createTempBlock(): (TempBlockId, File) = {
var blockId = new TempBlockId(UUID.randomUUID())
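
One subtlety in getAllBlocks above: java.io.File.list() returns null rather
than an empty array when a directory cannot be read, hence the explicit guard
before treating the result as a collection. A standalone sketch of the same
traversal (allFileNames is an illustrative name, not a Spark method):

    import java.io.File

    object ListBlocksSketch {
      // Walk a two-dimensional array of (possibly uninitialized) directories
      def allFileNames(subDirs: Array[Array[File]]): Seq[String] = {
        subDirs.flatten.filter(_ != null).toSeq.flatMap { dir =>
          val files = dir.list()
          if (files != null) files.toSeq else Seq.empty // list() is null on I/O error
        }
      }

      def main(args: Array[String]) {
        val tmp = new File(System.getProperty("java.io.tmpdir"))
        println(allFileNames(Array(Array(tmp))).take(5))
      }
    }
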
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index bb07c8c..4cd4cdb 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -169,23 +169,43 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
throw new IllegalStateException("Failed to find shuffle block: " + id)
}
+ /** Remove all the blocks / files and metadata related to a particular shuffle. */
+ def removeShuffle(shuffleId: ShuffleId): Boolean = {
+ // Do not change the ordering here: shuffleStates should be removed only
+ // after the corresponding shuffle blocks have been removed
+ val cleaned = removeShuffleBlocks(shuffleId)
+ shuffleStates.remove(shuffleId)
+ cleaned
+ }
+
+ /** Remove all the blocks / files related to a particular shuffle. */
+ private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
+ shuffleStates.get(shuffleId) match {
+ case Some(state) =>
+ if (consolidateShuffleFiles) {
+ for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+ file.delete()
+ }
+ } else {
+ for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ }
+ }
+ logInfo("Deleted all files for shuffle " + shuffleId)
+ true
+ case None =>
+ logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
+ false
+ }
+ }
+
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
private def cleanup(cleanupTime: Long) {
- shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
- if (consolidateShuffleFiles) {
- for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
- file.delete()
- }
- } else {
- for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
- val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
- blockManager.diskBlockManager.getFile(blockId).delete()
- }
- }
- })
+ shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 226ed2a..a107c51 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.ArrayBlockingQueue
import akka.actor._
import util.Random
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
@@ -48,7 +48,7 @@ private[spark] object ThreadingTest {
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
val startTime = System.currentTimeMillis()
- manager.put(blockId, block.iterator, level, true)
+ manager.put(blockId, block.iterator, level, tellMaster = true)
println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
queue.add((blockId, block))
}
@@ -101,7 +101,7 @@ private[spark] object ThreadingTest {
conf)
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
- new SecurityManager(conf))
+ new SecurityManager(conf), new MapOutputTrackerMaster(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0448919..7ebed51 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(
private[spark] object MetadataCleanerType extends Enumeration {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
- SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
+ SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
@@ -78,15 +78,16 @@ private[spark] object MetadataCleaner {
conf.getInt("spark.cleaner.ttl", -1)
}
- def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
- {
- conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString)
- .toInt
+ def getDelaySeconds(
+ conf: SparkConf,
+ cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
+ conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt
}
- def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType,
- delay: Int)
- {
+ def setDelaySeconds(
+ conf: SparkConf,
+ cleanerType: MetadataCleanerType.MetadataCleanerType,
+ delay: Int) {
conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index ddbd084..8de75ba 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -17,48 +17,54 @@
package org.apache.spark.util
+import java.util.Set
+import java.util.Map.Entry
import java.util.concurrent.ConcurrentHashMap
-import scala.collection.JavaConversions
-import scala.collection.immutable
-import scala.collection.mutable.Map
+import scala.collection.{JavaConversions, mutable}
import org.apache.spark.Logging
+private[spark] case class TimeStampedValue[V](value: V, timestamp: Long)
+
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* timestamp along with each key-value pair. If specified, the timestamp of each pair can be
* updated every time it is accessed. Key-value pairs whose timestamp are older than a particular
* threshold time can then be removed using the clearOldValues method. This is intended to
* be a drop-in replacement of scala.collection.mutable.HashMap.
- * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
- * updated when it is accessed
+ *
+ * @param updateTimeStampOnGet Whether the timestamp of a pair will be updated when it is accessed
*/
-class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
- extends Map[A, B]() with Logging {
- val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends mutable.Map[A, B]() with Logging {
+
+ private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
if (value != null && updateTimeStampOnGet) {
- internalMap.replace(key, value, (value._1, currentTime))
+ internalMap.replace(key, value, TimeStampedValue(value.value, currentTime))
}
- Option(value).map(_._1)
+ Option(value).map(_.value)
}
def iterator: Iterator[(A, B)] = {
- val jIterator = internalMap.entrySet().iterator()
- JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
+ val jIterator = getEntrySet.iterator
+ JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value))
}
- override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+ def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet
+
+ override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
val newMap = new TimeStampedHashMap[A, B1]
- newMap.internalMap.putAll(this.internalMap)
- newMap.internalMap.put(kv._1, (kv._2, currentTime))
+ val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]]
+ newMap.internalMap.putAll(oldInternalMap)
+ kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) }
newMap
}
- override def - (key: A): Map[A, B] = {
+ override def - (key: A): mutable.Map[A, B] = {
val newMap = new TimeStampedHashMap[A, B]
newMap.internalMap.putAll(this.internalMap)
newMap.internalMap.remove(key)
@@ -66,17 +72,10 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
}
override def += (kv: (A, B)): this.type = {
- internalMap.put(kv._1, (kv._2, currentTime))
+ kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) }
this
}
- // Should we return previous value directly or as Option ?
- def putIfAbsent(key: A, value: B): Option[B] = {
- val prev = internalMap.putIfAbsent(key, (value, currentTime))
- if (prev != null) Some(prev._1) else None
- }
-
-
override def -= (key: A): this.type = {
internalMap.remove(key)
this
@@ -87,53 +86,65 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
}
override def apply(key: A): B = {
- val value = internalMap.get(key)
- if (value == null) throw new NoSuchElementException()
- value._1
+ get(key).getOrElse { throw new NoSuchElementException() }
}
- override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
- JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = {
+ JavaConversions.mapAsScalaConcurrentMap(internalMap)
+ .map { case (k, TimeStampedValue(v, t)) => (k, v) }
+ .filter(p)
}
- override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+ override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]()
override def size: Int = internalMap.size
override def foreach[U](f: ((A, B)) => U) {
- val iterator = internalMap.entrySet().iterator()
- while(iterator.hasNext) {
- val entry = iterator.next()
- val kv = (entry.getKey, entry.getValue._1)
+ val it = getEntrySet.iterator
+ while(it.hasNext) {
+ val entry = it.next()
+ val kv = (entry.getKey, entry.getValue.value)
f(kv)
}
}
- def toMap: immutable.Map[A, B] = iterator.toMap
+ def putIfAbsent(key: A, value: B): Option[B] = {
+ val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime))
+ Option(prev).map(_.value)
+ }
+
+ def putAll(map: Map[A, B]) {
+ map.foreach { case (k, v) => update(k, v) }
+ }
+
+ def toMap: Map[A, B] = iterator.toMap
- /**
- * Removes old key-value pairs that have timestamp earlier than `threshTime`,
- * calling the supplied function on each such entry before removing.
- */
def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
- val iterator = internalMap.entrySet().iterator()
- while (iterator.hasNext) {
- val entry = iterator.next()
- if (entry.getValue._2 < threshTime) {
- f(entry.getKey, entry.getValue._1)
+ val it = getEntrySet.iterator
+ while (it.hasNext) {
+ val entry = it.next()
+ if (entry.getValue.timestamp < threshTime) {
+ f(entry.getKey, entry.getValue.value)
logDebug("Removing key " + entry.getKey)
- iterator.remove()
+ it.remove()
}
}
}
- /**
- * Removes old key-value pairs that have timestamp earlier than `threshTime`
- */
+ /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */
def clearOldValues(threshTime: Long) {
clearOldValues(threshTime, (_, _) => ())
}
- private def currentTime: Long = System.currentTimeMillis()
+ private def currentTime: Long = System.currentTimeMillis
+ // For testing
+
+ def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = {
+ Option(internalMap.get(key))
+ }
+
+ def getTimestamp(key: A): Option[Long] = {
+ getTimeStampedValue(key).map(_.timestamp)
+ }
}
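
The mechanism behind TimeStampedHashMap, condensed into a standalone sketch:
every value is wrapped together with its insertion time, and clearOldValues
is a single pass over the backing ConcurrentHashMap's entry set that drops
anything older than a threshold (Stamped is a simplified stand-in for the
TimeStampedValue case class above):

    import java.util.concurrent.ConcurrentHashMap

    object ClearOldValuesSketch {
      // Simplified stand-in for TimeStampedValue[V]
      case class Stamped[V](value: V, timestamp: Long)

      def main(args: Array[String]) {
        val map = new ConcurrentHashMap[String, Stamped[Int]]()
        map.put("old", Stamped(1, System.currentTimeMillis - 10000))
        map.put("new", Stamped(2, System.currentTimeMillis))

        // Drop entries inserted more than five seconds ago
        val threshTime = System.currentTimeMillis - 5000
        val it = map.entrySet.iterator
        while (it.hasNext) {
          if (it.next().getValue.timestamp < threshTime) {
            it.remove()
          }
        }
        println(map.keySet) // [new]
      }
    }
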
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
new file mode 100644
index 0000000..b65017d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.lang.ref.WeakReference
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+
+/**
+ * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
+ *
+ * If the value is garbage collected and the weak reference is null, get() will return a
+ * non-existent value. These entries are removed from the map periodically (every N inserts), as
+ * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are
+ * older than a particular threshold can be removed using the clearOldValues method.
+ *
+ * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it
+ * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap,
+ * so all operations on this HashMap are thread-safe.
+ *
+ * @param updateTimeStampOnGet Whether the timestamp of a pair will be updated when it is accessed.
+ */
+private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends mutable.Map[A, B]() with Logging {
+
+ import TimeStampedWeakValueHashMap._
+
+ private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)
+ private val insertCount = new AtomicInteger(0)
+
+ /** Return a map consisting only of entries whose values are still strongly reachable. */
+ private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null }
+
+ def get(key: A): Option[B] = internalMap.get(key)
+
+ def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator
+
+ override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
+ val newMap = new TimeStampedWeakValueHashMap[A, B1]
+ val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]]
+ newMap.internalMap.putAll(oldMap.toMap)
+ newMap.internalMap += kv
+ newMap
+ }
+
+ override def - (key: A): mutable.Map[A, B] = {
+ val newMap = new TimeStampedWeakValueHashMap[A, B]
+ newMap.internalMap.putAll(nonNullReferenceMap.toMap)
+ newMap.internalMap -= key
+ newMap
+ }
+
+ override def += (kv: (A, B)): this.type = {
+ internalMap += kv
+ if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) {
+ clearNullValues()
+ }
+ this
+ }
+
+ override def -= (key: A): this.type = {
+ internalMap -= key
+ this
+ }
+
+ override def update(key: A, value: B) = this += ((key, value))
+
+ override def apply(key: A): B = internalMap.apply(key)
+
+ override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p)
+
+ override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()
+
+ override def size: Int = internalMap.size
+
+ override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f)
+
+ def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
+
+ def toMap: Map[A, B] = iterator.toMap
+
+ /** Remove old key-value pairs with timestamps earlier than `threshTime`. */
+ def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)
+
+ /** Remove entries with values that are no longer strongly reachable. */
+ def clearNullValues() {
+ val it = internalMap.getEntrySet.iterator
+ while (it.hasNext) {
+ val entry = it.next()
+ if (entry.getValue.value.get == null) {
+ logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.")
+ it.remove()
+ }
+ }
+ }
+
+ // For testing
+
+ def getTimestamp(key: A): Option[Long] = {
+ internalMap.getTimeStampedValue(key).map(_.timestamp)
+ }
+
+ def getReference(key: A): Option[WeakReference[B]] = {
+ internalMap.getTimeStampedValue(key).map(_.value)
+ }
+}
+
+/**
+ * Helper methods for converting to and from WeakReferences.
+ */
+private object TimeStampedWeakValueHashMap {
+
+ // Number of inserts after which entries with null references are removed
+ val CLEAR_NULL_VALUES_INTERVAL = 100
+
+ /* Implicit conversion methods to WeakReferences. */
+
+ implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)
+
+ implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = {
+ kv match { case (k, v) => (k, toWeakReference(v)) }
+ }
+
+ implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = {
+ (kv: (K, WeakReference[V])) => p(kv)
+ }
+
+ /* Implicit conversion methods from WeakReferences. */
+
+ implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get
+
+ implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
+ v match {
+ case Some(ref) => Option(fromWeakReference(ref))
+ case None => None
+ }
+ }
+
+ implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
+ kv match { case (k, v) => (k, fromWeakReference(v)) }
+ }
+
+ implicit def fromWeakReferenceIterator[K, V](
+ it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = {
+ it.map(fromWeakReferenceTuple)
+ }
+
+ implicit def fromWeakReferenceMap[K, V](
+ map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
+ mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
+ }
+}
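
What the new TimeStampedWeakValueHashMap relies on, in a standalone sketch:
once the last strong reference to a value is dropped, a garbage collection
may clear the WeakReference, after which a clearNullValues-style sweep can
remove the entry. GC timing is up to the JVM, so the final size below is
likely, not guaranteed, to be zero:

    import java.lang.ref.WeakReference
    import java.util.concurrent.ConcurrentHashMap

    object WeakValueSketch {
      def main(args: Array[String]) {
        val map = new ConcurrentHashMap[String, WeakReference[Object]]()
        var value: Object = new Object
        map.put("key", new WeakReference(value))

        value = null // drop the only strong reference
        System.gc()  // request (not force) a collection
        Thread.sleep(100)

        // The clearNullValues-style sweep: drop entries whose referent is gone
        val it = map.entrySet.iterator
        while (it.hasNext) {
          if (it.next().getValue.get == null) {
            it.remove()
          }
        }
        println(map.size) // likely 0, depending on the JVM's GC behavior
      }
    }
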
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 4435b21..59da51f 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -499,10 +499,10 @@ private[spark] object Utils extends Logging {
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
def parseHostPort(hostPort: String): (String, Int) = {
- {
- // Check cache first.
- val cached = hostPortParseResults.get(hostPort)
- if (cached != null) return cached
+ // Check cache first.
+ val cached = hostPortParseResults.get(hostPort)
+ if (cached != null) {
+ return cached
}
val indx: Int = hostPort.lastIndexOf(':')
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
index d2e303d..c5f24c6 100644
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -56,7 +56,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = conf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@@ -93,7 +93,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = badconf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@@ -147,7 +147,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = goodconf, securityManager = securityManagerGood)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@@ -200,7 +200,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = badconf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index 96ba392..c993625 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -19,68 +19,297 @@ package org.apache.spark
import org.scalatest.FunSuite
-class BroadcastSuite extends FunSuite with LocalSparkContext {
+import org.apache.spark.storage._
+import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
+import org.apache.spark.storage.BroadcastBlockId
+class BroadcastSuite extends FunSuite with LocalSparkContext {
- override def afterEach() {
- super.afterEach()
- System.clearProperty("spark.broadcast.factory")
- }
+ private val httpConf = broadcastConf("HttpBroadcastFactory")
+ private val torrentConf = broadcastConf("TorrentBroadcastFactory")
test("Using HttpBroadcast locally") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
- sc = new SparkContext("local", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ sc = new SparkContext("local", "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === Set((1, 10), (2, 10)))
}
test("Accessing HttpBroadcast variables from multiple threads") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
- sc = new SparkContext("local[10]", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local[10]", "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing HttpBroadcast variables in a local cluster") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
val numSlaves = 4
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
test("Using TorrentBroadcast locally") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
- sc = new SparkContext("local", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ sc = new SparkContext("local", "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === Set((1, 10), (2, 10)))
}
test("Accessing TorrentBroadcast variables from multiple threads") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
- sc = new SparkContext("local[10]", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local[10]", "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing TorrentBroadcast variables in a local cluster") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
val numSlaves = 4
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Unpersisting HttpBroadcast on executors only in local mode") {
+ testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
+ }
+
+ test("Unpersisting HttpBroadcast on executors and driver in local mode") {
+ testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
+ }
+
+ test("Unpersisting HttpBroadcast on executors only in distributed mode") {
+ testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
+ }
+
+ test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
+ testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors only in local mode") {
+ testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
+ testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
+ testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
+ testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
+ }
+ /**
+ * Verify the persistence of state associated with an HttpBroadcast in either local mode or
+ * local-cluster mode (when distributed = true).
+ *
+ * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+ * In between each step, this test verifies that the broadcast blocks and the broadcast file
+ * are present only on the expected nodes.
+ */
+ private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ val numSlaves = if (distributed) 2 else 0
+
+ def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
+
+ // Verify that the broadcast file is created, and blocks are persisted only on the driver
+ def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, status) =>
+ assert(bm.executorId === "<driver>", "Block should only be on the driver")
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store on the driver")
+ assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+ }
+ if (distributed) {
+ // this file is only generated in distributed mode
+ assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ }
+ }
+
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ assert(statuses.size === numSlaves + 1)
+ statuses.foreach { case (_, status) =>
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store")
+ assert(status.diskSize === 0, "Block should not be in disk store")
+ }
+ }
+
+ // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+ // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
+ def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ val expectedNumBlocks = if (removeFromDriver) 0 else 1
+ val possiblyNot = if (removeFromDriver) "" else " not"
+ assert(statuses.size === expectedNumBlocks,
+ "Block should%s be unpersisted on the driver".format(possiblyNot))
+ if (distributed && removeFromDriver) {
+ // this file is only generated in distributed mode
+ assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ "Broadcast file should%s be deleted".format(possiblyNot))
+ }
+ }
+
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ afterUsingBroadcast, afterUnpersist, removeFromDriver)
+ }
+
+ /**
+ * Verify the persistence of state associated with a TorrentBroadcast in either local mode
+ * or local-cluster mode (when distributed = true).
+ *
+ * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+ * In between each step, this test verifies that the broadcast blocks are present only on the
+ * expected nodes.
+ */
+ private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ val numSlaves = if (distributed) 2 else 0
+
+ def getBlockIds(id: Long) = {
+ val broadcastBlockId = BroadcastBlockId(id)
+ val metaBlockId = BroadcastBlockId(id, "meta")
+ // Assume broadcast value is small enough to fit into 1 piece
+ val pieceBlockId = BroadcastBlockId(id, "piece0")
+ if (distributed) {
+ // the metadata and piece blocks are generated only in distributed mode
+ Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+ } else {
+ Seq[BroadcastBlockId](broadcastBlockId)
+ }
+ }
+
+ // Verify that blocks are persisted only on the driver
+ def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, status) =>
+ assert(bm.executorId === "<driver>", "Block should only be on the driver")
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store on the driver")
+ assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+ }
+ }
+ }
+
+  // Verify that blocks are persisted on both the executors and the driver
+ def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (blockId.field == "meta") {
+          // The meta block is only on the driver
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+ } else {
+ // Other blocks are on both the executors and the driver
+ assert(statuses.size === numSlaves + 1,
+ blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
+ statuses.foreach { case (_, status) =>
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store")
+ assert(status.diskSize === 0, "Block should not be in disk store")
+ }
+ }
+ }
+ }
+
+ // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+ // is true.
+ def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ val expectedNumBlocks = if (removeFromDriver) 0 else 1
+ val possiblyNot = if (removeFromDriver) "" else " not"
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks,
+ "Block should%s be unpersisted on the driver".format(possiblyNot))
+ }
+ }
+
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ afterUsingBroadcast, afterUnpersist, removeFromDriver)
+ }
+
+ /**
+ * This test runs in 4 steps:
+ *
+ * 1) Create broadcast variable, and verify that all state is persisted on the driver.
+ * 2) Use the broadcast variable on all executors, and verify that all state is persisted
+ * on both the driver and the executors.
+   * 3) Unpersist the broadcast, and verify that all state is removed where it should be.
+   * 4) [Optional] If removeFromDriver is false, verify that the broadcast is reusable.
+ */
+ private def testUnpersistBroadcast(
+ distributed: Boolean,
+ numSlaves: Int, // used only when distributed = true
+ broadcastConf: SparkConf,
+ getBlockIds: Long => Seq[BroadcastBlockId],
+ afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ removeFromDriver: Boolean) {
+
+ sc = if (distributed) {
+ new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+ } else {
+ new SparkContext("local", "test", broadcastConf)
+ }
+ val blockManagerMaster = sc.env.blockManager.master
+ val list = List[Int](1, 2, 3, 4)
+
+ // Create broadcast variable
+ val broadcast = sc.broadcast(list)
+ val blocks = getBlockIds(broadcast.id)
+ afterCreation(blocks, blockManagerMaster)
+
+ // Use broadcast variable on all executors
+ val partitions = 10
+ assert(partitions > numSlaves)
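+    // With more partitions than executors, every executor should run at least one task
+    // and therefore fetch the broadcast value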
+ val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+ afterUsingBroadcast(blocks, blockManagerMaster)
+
+ // Unpersist broadcast
+ if (removeFromDriver) {
+ broadcast.destroy(blocking = true)
+ } else {
+ broadcast.unpersist(blocking = true)
+ }
+ afterUnpersist(blocks, blockManagerMaster)
+
+ // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
+ // should throw SparkExceptions. Otherwise, the result should be the same as before.
+ if (removeFromDriver) {
+ // Using this variable on the executors crashes them, which hangs the test.
+ // Instead, crash the driver by directly accessing the broadcast value.
+ intercept[SparkException] { broadcast.value }
+ intercept[SparkException] { broadcast.unpersist() }
+ intercept[SparkException] { broadcast.destroy(blocking = true) }
+ } else {
+ val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+ }
}
+ /** Helper method to create a SparkConf that uses the given broadcast factory. */
+ private def broadcastConf(factoryName: String): SparkConf = {
+ val conf = new SparkConf
+ conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
+ conf
+ }
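+
+  // For example, broadcastConf("HttpBroadcastFactory") returns a SparkConf with
+  // spark.broadcast.factory set to "org.apache.spark.broadcast.HttpBroadcastFactory"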
}
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
new file mode 100644
index 0000000..e50981c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.ref.WeakReference
+
+import scala.collection.mutable.{HashSet, SynchronizedSet}
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
+
+class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
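+  // Implicit timeout picked up by CleanerTester.assertCleanup() below, unless a
+  // more specific timeout is passed in explicitly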
+ implicit val defaultTimeout = timeout(10000 millis)
+ val conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("ContextCleanerSuite")
+ .set("spark.cleaner.referenceTracking.blocking", "true")
+
+ before {
+ sc = new SparkContext(conf)
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ }
+
+
+ test("cleanup RDD") {
+ val rdd = newRDD.persist()
+ val collected = rdd.collect().toList
+ val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+
+ // Explicit cleanup
+ cleaner.doCleanupRDD(rdd.id, blocking = true)
+ tester.assertCleanup()
+
+ // Verify that RDDs can be re-executed after cleaning up
+ assert(rdd.collect().toList === collected)
+ }
+
+ test("cleanup shuffle") {
+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
+ val collected = rdd.collect().toList
+ val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
+
+ // Explicit cleanup
+ shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+ tester.assertCleanup()
+
+ // Verify that shuffles can be re-executed after cleaning up
+ assert(rdd.collect().toList === collected)
+ }
+
+ test("cleanup broadcast") {
+ val broadcast = newBroadcast
+ val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+
+ // Explicit cleanup
+ cleaner.doCleanupBroadcast(broadcast.id, blocking = true)
+ tester.assertCleanup()
+ }
+
+ test("automatically cleanup RDD") {
+ var rdd = newRDD.persist()
+ rdd.count()
+
+ // Test that GC does not cause RDD cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes RDD cleanup after dereferencing the RDD
+ val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ rdd = null // Make RDD out of scope
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup shuffle") {
+ var rdd = newShuffleRDD
+ rdd.count()
+
+ // Test that GC does not cause shuffle cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes shuffle cleanup after dereferencing the RDD
+ val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup broadcast") {
+ var broadcast = newBroadcast
+
+ // Test that GC does not cause broadcast cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes broadcast cleanup after dereferencing the broadcast variable
+ val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+ broadcast = null // Make broadcast variable out of scope
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup RDD + shuffle + broadcast") {
+ val numRdds = 100
+ val numBroadcasts = 4 // Broadcasts are more costly
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+ val rddIds = sc.persistentRdds.keys.toSeq
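+    // Every shuffle created so far has an id in the range [0, sc.newShuffleId)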
+ val shuffleIds = 0 until sc.newShuffleId
+ val broadcastIds = 0L until numBroadcasts
+
+ val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+    // Test that GC triggers the cleanup of all variables after dereferencing them
+ val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ broadcastBuffer.clear()
+ rddBuffer.clear()
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+ sc.stop()
+
+ val conf2 = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("ContextCleanerSuite")
+ .set("spark.cleaner.referenceTracking.blocking", "true")
+ sc = new SparkContext(conf2)
+
+ val numRdds = 10
+ val numBroadcasts = 4 // Broadcasts are more costly
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+ val rddIds = sc.persistentRdds.keys.toSeq
+ val shuffleIds = 0 until sc.newShuffleId
+ val broadcastIds = 0L until numBroadcasts
+
+ val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+    // Test that GC triggers the cleanup of all variables after dereferencing them
+ val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ broadcastBuffer.clear()
+ rddBuffer.clear()
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ //------ Helper functions ------
+
+ def newRDD = sc.makeRDD(1 to 10)
+ def newPairRDD = newRDD.map(_ -> 1)
+ def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
+ def newBroadcast = sc.broadcast(1 to 100)
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+ def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+ rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+ getAllDependencies(dep.rdd)
+ }
+ }
+ val rdd = newShuffleRDD
+
+ // Get all the shuffle dependencies
+ val shuffleDeps = getAllDependencies(rdd)
+ .filter(_.isInstanceOf[ShuffleDependency[_, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _]])
+ (rdd, shuffleDeps)
+ }
+
+ def randomRdd = {
+ val rdd: RDD[_] = Random.nextInt(3) match {
+ case 0 => newRDD
+ case 1 => newShuffleRDD
+ case 2 => newPairRDD.join(newPairRDD)
+ }
+ if (Random.nextBoolean()) rdd.persist()
+ rdd.count()
+ rdd
+ }
+
+ def randomBroadcast = {
+ sc.broadcast(Random.nextInt(Int.MaxValue))
+ }
+
+ /** Run GC and make sure it actually has run */
+ def runGC() {
+ val weakRef = new WeakReference(new Object())
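+    // The weak reference acts as a GC sentinel: it is cleared only once its referent
+    // (the otherwise unreachable Object above) has actually been collected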
+ val startTime = System.currentTimeMillis
+ System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+ // Wait until a weak reference object has been GCed
+    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ System.gc()
+ Thread.sleep(200)
+ }
+ }
+
+ def cleaner = sc.cleaner.get
+}
+
+
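+// Usage sketch (mirroring the tests above): construct the tester while the resources
+// still exist, trigger cleanup, then assert:
+//   val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+//   sc.cleaner.get.doCleanupRDD(rdd.id, blocking = true)
+//   tester.assertCleanup()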
+/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */
+class CleanerTester(
+ sc: SparkContext,
+ rddIds: Seq[Int] = Seq.empty,
+ shuffleIds: Seq[Int] = Seq.empty,
+ broadcastIds: Seq[Long] = Seq.empty)
+ extends Logging {
+
+ val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds
+ val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds
+  val toBeCleanedBroadcastIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds
+ val isDistributed = !sc.isLocal
+
+ val cleanerListener = new CleanerListener {
+ def rddCleaned(rddId: Int): Unit = {
+ toBeCleanedRDDIds -= rddId
+ logInfo("RDD "+ rddId + " cleaned")
+ }
+
+ def shuffleCleaned(shuffleId: Int): Unit = {
+ toBeCleanedShuffleIds -= shuffleId
+ logInfo("Shuffle " + shuffleId + " cleaned")
+ }
+
+ def broadcastCleaned(broadcastId: Long): Unit = {
+      toBeCleanedBroadcastIds -= broadcastId
+      logInfo("Broadcast " + broadcastId + " cleaned")
+ }
+ }
+
+ val MAX_VALIDATION_ATTEMPTS = 10
+ val VALIDATION_ATTEMPT_INTERVAL = 100
+
+ logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString)
+ preCleanupValidate()
+ sc.cleaner.get.attachListener(cleanerListener)
+
+  /** Assert that all of the tracked resources have been cleaned up */
+ def assertCleanup()(implicit waitTimeout: Eventually.Timeout) {
+ try {
+ eventually(waitTimeout, interval(100 millis)) {
+ assert(isAllCleanedUp)
+ }
+ postCleanupValidate()
+ } finally {
+ logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString)
+ }
+ }
+
+ /** Verify that RDDs, shuffles, etc. occupy resources */
+ private def preCleanupValidate() {
+    assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to clean up")
+
+ // Verify the RDDs have been persisted and blocks are present
+ rddIds.foreach { rddId =>
+ assert(
+ sc.persistentRdds.contains(rddId),
+ "RDD " + rddId + " have not been persisted, cannot start cleaner test"
+ )
+
+ assert(
+ !getRDDBlocks(rddId).isEmpty,
+ "Blocks of RDD " + rddId + " cannot be found in block manager, " +
+ "cannot start cleaner test"
+ )
+ }
+
+ // Verify the shuffle ids are registered and blocks are present
+ shuffleIds.foreach { shuffleId =>
+ assert(
+ mapOutputTrackerMaster.containsShuffle(shuffleId),
+ "Shuffle " + shuffleId + " have not been registered, cannot start cleaner test"
+ )
+
+ assert(
+ !getShuffleBlocks(shuffleId).isEmpty,
+ "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " +
+ "cannot start cleaner test"
+ )
+ }
+
+ // Verify that the broadcast blocks are present
+ broadcastIds.foreach { broadcastId =>
+ assert(
+ !getBroadcastBlocks(broadcastId).isEmpty,
+ "Blocks of broadcast " + broadcastId + "cannot be found in block manager, " +
+ "cannot start cleaner test"
+ )
+ }
+ }
+
+  /**
+   * Verify that RDDs, shuffles, etc. no longer occupy resources. Tests multiple times,
+   * as there is no guarantee on how long it will take to clean up the resources.
+   */
+ private def postCleanupValidate() {
+    // Verify that the RDDs have been unpersisted and their blocks removed
+ rddIds.foreach { rddId =>
+ assert(
+ !sc.persistentRdds.contains(rddId),
+ "RDD " + rddId + " was not cleared from sc.persistentRdds"
+ )
+
+ assert(
+ getRDDBlocks(rddId).isEmpty,
+ "Blocks of RDD " + rddId + " were not cleared from block manager"
+ )
+ }
+
+    // Verify that the shuffles have been deregistered and their blocks removed
+ shuffleIds.foreach { shuffleId =>
+ assert(
+ !mapOutputTrackerMaster.containsShuffle(shuffleId),
+ "Shuffle " + shuffleId + " was not deregistered from map output tracker"
+ )
+
+ assert(
+ getShuffleBlocks(shuffleId).isEmpty,
+ "Blocks of shuffle " + shuffleId + " were not cleared from block manager"
+ )
+ }
+
+    // Verify that the broadcast blocks have been removed
+ broadcastIds.foreach { broadcastId =>
+ assert(
+ getBroadcastBlocks(broadcastId).isEmpty,
+ "Blocks of broadcast " + broadcastId + " were not cleared from block manager"
+ )
+ }
+ }
+
+ private def uncleanedResourcesToString = {
+ s"""
+ |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")}
+ |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")}
+      |\tBroadcasts = ${toBeCleanedBroadcastIds.toSeq.sorted.mkString("[", ", ", "]")}
+ """.stripMargin
+ }
+
+ private def isAllCleanedUp =
+ toBeCleanedRDDIds.isEmpty &&
+ toBeCleanedShuffleIds.isEmpty &&
+    toBeCleanedBroadcastIds.isEmpty
+
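+  // Note: the backquoted identifiers in the patterns below (e.g. `rddId`) match against
+  // the enclosing method's parameter value rather than binding a new variable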
+ private def getRDDBlocks(rddId: Int): Seq[BlockId] = {
+    blockManager.master.getMatchingBlockIds(_ match {
+ case RDDBlockId(`rddId`, _) => true
+ case _ => false
+ }, askSlaves = true)
+ }
+
+ private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
+    blockManager.master.getMatchingBlockIds(_ match {
+ case ShuffleBlockId(`shuffleId`, _, _) => true
+ case _ => false
+ }, askSlaves = true)
+ }
+
+ private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = {
+    blockManager.master.getMatchingBlockIds(_ match {
+ case BroadcastBlockId(`broadcastId`, _) => true
+ case _ => false
+ }, askSlaves = true)
+ }
+
+ private def blockManager = sc.env.blockManager
+ private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/11eabbe1/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index a5bd72e..6b2571c 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -57,12 +57,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.stop()
}
- test("master register and fetch") {
+ test("master register shuffle and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
+ assert(tracker.containsShuffle(10))
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
@@ -77,7 +78,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.stop()
}
- test("master register and unregister and fetch") {
+ test("master register and unregister shuffle") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTrackerMaster(conf)
+ tracker.trackerActor =
+ actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+ tracker.registerShuffle(10, 2)
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ Array(compressedSize1000, compressedSize10000)))
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ Array(compressedSize10000, compressedSize1000)))
+ assert(tracker.containsShuffle(10))
+ assert(tracker.getServerStatuses(10, 0).nonEmpty)
+ tracker.unregisterShuffle(10)
+ assert(!tracker.containsShuffle(10))
+ assert(tracker.getServerStatuses(10, 0).isEmpty)
+ }
+
+ test("master register shuffle and unregister map output and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor =
@@ -114,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
securityManager = new SecurityManager(conf))
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)