You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/04/04 20:52:08 UTC
[1/2] spark git commit: [SPARK-6602][Core] Replace direct use of Akka
with Spark RPC interface - part 1
Repository: spark
Updated Branches:
refs/heads/master 7bca62f79 -> f15806a8f
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
deleted file mode 100644
index 5b53280..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ /dev/null
@@ -1,512 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.mutable
-import scala.collection.JavaConversions._
-import scala.concurrent.Future
-import scala.concurrent.duration._
-
-import akka.actor.{Actor, ActorRef}
-import akka.pattern.ask
-
-import org.apache.spark.{Logging, SparkConf, SparkException}
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.scheduler._
-import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils}
-
-/**
- * BlockManagerMasterActor is an actor on the master node to track statuses of
- * all slaves' block managers.
- */
-private[spark]
-class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus)
- extends Actor with ActorLogReceive with Logging {
-
- // Mapping from block manager id to the block manager's information.
- private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
-
- // Mapping from executor ID to block manager ID.
- private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
-
- // Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
-
- private val akkaTimeout = AkkaUtils.askTimeout(conf)
-
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
- register(blockManagerId, maxMemSize, slaveActor)
- sender ! true
-
- case UpdateBlockInfo(
- blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
- sender ! updateBlockInfo(
- blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
-
- case GetLocations(blockId) =>
- sender ! getLocations(blockId)
-
- case GetLocationsMultipleBlockIds(blockIds) =>
- sender ! getLocationsMultipleBlockIds(blockIds)
-
- case GetPeers(blockManagerId) =>
- sender ! getPeers(blockManagerId)
-
- case GetActorSystemHostPortForExecutor(executorId) =>
- sender ! getActorSystemHostPortForExecutor(executorId)
-
- case GetMemoryStatus =>
- sender ! memoryStatus
-
- case GetStorageStatus =>
- sender ! storageStatus
-
- case GetBlockStatus(blockId, askSlaves) =>
- sender ! blockStatus(blockId, askSlaves)
-
- case GetMatchingBlockIds(filter, askSlaves) =>
- sender ! getMatchingBlockIds(filter, askSlaves)
-
- case RemoveRdd(rddId) =>
- sender ! removeRdd(rddId)
-
- case RemoveShuffle(shuffleId) =>
- sender ! removeShuffle(shuffleId)
-
- case RemoveBroadcast(broadcastId, removeFromDriver) =>
- sender ! removeBroadcast(broadcastId, removeFromDriver)
-
- case RemoveBlock(blockId) =>
- removeBlockFromWorkers(blockId)
- sender ! true
-
- case RemoveExecutor(execId) =>
- removeExecutor(execId)
- sender ! true
-
- case StopBlockManagerMaster =>
- sender ! true
- context.stop(self)
-
- case BlockManagerHeartbeat(blockManagerId) =>
- sender ! heartbeatReceived(blockManagerId)
-
- case other =>
- logWarning("Got unknown message: " + other)
- }
-
- private def removeRdd(rddId: Int): Future[Seq[Int]] = {
- // First remove the metadata for the given RDD, and then asynchronously remove the blocks
- // from the slaves.
-
- // Find all blocks for the given RDD, remove the block from both blockLocations and
- // the blockManagerInfo that is tracking the blocks.
- val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
- blocks.foreach { blockId =>
- val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
- bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
- blockLocations.remove(blockId)
- }
-
- // Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
- // The dispatcher is used as an implicit argument into the Future sequence construction.
- import context.dispatcher
- val removeMsg = RemoveRdd(rddId)
- Future.sequence(
- blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
- }.toSeq
- )
- }
-
- private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
- // Nothing to do in the BlockManagerMasterActor data structures
- import context.dispatcher
- val removeMsg = RemoveShuffle(shuffleId)
- Future.sequence(
- blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
- }.toSeq
- )
- }
-
- /**
- * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified
- * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
- * from the executors, but not from the driver.
- */
- private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
- import context.dispatcher
- val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
- val requiredBlockManagers = blockManagerInfo.values.filter { info =>
- removeFromDriver || !info.blockManagerId.isDriver
- }
- Future.sequence(
- requiredBlockManagers.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
- }.toSeq
- )
- }
-
- private def removeBlockManager(blockManagerId: BlockManagerId) {
- val info = blockManagerInfo(blockManagerId)
-
- // Remove the block manager from blockManagerIdByExecutor.
- blockManagerIdByExecutor -= blockManagerId.executorId
-
- // Remove it from blockManagerInfo and remove all the blocks.
- blockManagerInfo.remove(blockManagerId)
- val iterator = info.blocks.keySet.iterator
- while (iterator.hasNext) {
- val blockId = iterator.next
- val locations = blockLocations.get(blockId)
- locations -= blockManagerId
- if (locations.size == 0) {
- blockLocations.remove(blockId)
- }
- }
- listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
- logInfo(s"Removing block manager $blockManagerId")
- }
-
- private def removeExecutor(execId: String) {
- logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
- blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
- }
-
- /**
- * Return true if the driver knows about the given block manager. Otherwise, return false,
- * indicating that the block manager should re-register.
- */
- private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = {
- if (!blockManagerInfo.contains(blockManagerId)) {
- blockManagerId.isDriver && !isLocal
- } else {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- true
- }
- }
-
- // Remove a block from the slaves that have it. This can only be used to remove
- // blocks that the master knows about.
- private def removeBlockFromWorkers(blockId: BlockId) {
- val locations = blockLocations.get(blockId)
- if (locations != null) {
- locations.foreach { blockManagerId: BlockManagerId =>
- val blockManager = blockManagerInfo.get(blockManagerId)
- if (blockManager.isDefined) {
- // Remove the block from the slave's BlockManager.
- // Doesn't actually wait for a confirmation and the message might get lost.
- // If message loss becomes frequent, we should add retry logic here.
- blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout)
- }
- }
- }
- }
-
- // Return a map from the block manager id to max memory and remaining memory.
- private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
- blockManagerInfo.map { case(blockManagerId, info) =>
- (blockManagerId, (info.maxMem, info.remainingMem))
- }.toMap
- }
-
- private def storageStatus: Array[StorageStatus] = {
- blockManagerInfo.map { case (blockManagerId, info) =>
- new StorageStatus(blockManagerId, info.maxMem, info.blocks)
- }.toArray
- }
-
- /**
- * Return the block's status for all block managers, if any. NOTE: This is a
- * potentially expensive operation and should only be used for testing.
- *
- * If askSlaves is true, the master queries each block manager for the most updated block
- * statuses. This is useful when the master is not informed of the given block by all block
- * managers.
- */
- private def blockStatus(
- blockId: BlockId,
- askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
- import context.dispatcher
- val getBlockStatus = GetBlockStatus(blockId)
- /*
- * Rather than blocking on the block status query, master actor should simply return
- * Futures to avoid potential deadlocks. This can arise if there exists a block manager
- * that is also waiting for this master actor's response to a previous message.
- */
- blockManagerInfo.values.map { info =>
- val blockStatusFuture =
- if (askSlaves) {
- info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
- } else {
- Future { info.getStatus(blockId) }
- }
- (info.blockManagerId, blockStatusFuture)
- }.toMap
- }
-
- /**
- * Return the ids of blocks present in all the block managers that match the given filter.
- * NOTE: This is a potentially expensive operation and should only be used for testing.
- *
- * If askSlaves is true, the master queries each block manager for the most updated block
- * statuses. This is useful when the master is not informed of the given block by all block
- * managers.
- */
- private def getMatchingBlockIds(
- filter: BlockId => Boolean,
- askSlaves: Boolean): Future[Seq[BlockId]] = {
- import context.dispatcher
- val getMatchingBlockIds = GetMatchingBlockIds(filter)
- Future.sequence(
- blockManagerInfo.values.map { info =>
- val future =
- if (askSlaves) {
- info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
- } else {
- Future { info.blocks.keys.filter(filter).toSeq }
- }
- future
- }
- ).map(_.flatten.toSeq)
- }
-
- private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- val time = System.currentTimeMillis()
- if (!blockManagerInfo.contains(id)) {
- blockManagerIdByExecutor.get(id.executorId) match {
- case Some(oldId) =>
- // A block manager of the same executor already exists, so remove it (assumed dead)
- logError("Got two different block manager registrations on same executor - "
- + s" will replace old one $oldId with new one $id")
- removeExecutor(id.executorId)
- case None =>
- }
- logInfo("Registering block manager %s with %s RAM, %s".format(
- id.hostPort, Utils.bytesToString(maxMemSize), id))
-
- blockManagerIdByExecutor(id.executorId) = id
-
- blockManagerInfo(id) = new BlockManagerInfo(
- id, System.currentTimeMillis(), maxMemSize, slaveActor)
- }
- listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
- }
-
- private def updateBlockInfo(
- blockManagerId: BlockManagerId,
- blockId: BlockId,
- storageLevel: StorageLevel,
- memSize: Long,
- diskSize: Long,
- tachyonSize: Long): Boolean = {
-
- if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.isDriver && !isLocal) {
- // We intentionally do not register the master (except in local mode),
- // so we should not indicate failure.
- return true
- } else {
- return false
- }
- }
-
- if (blockId == null) {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- return true
- }
-
- blockManagerInfo(blockManagerId).updateBlockInfo(
- blockId, storageLevel, memSize, diskSize, tachyonSize)
-
- var locations: mutable.HashSet[BlockManagerId] = null
- if (blockLocations.containsKey(blockId)) {
- locations = blockLocations.get(blockId)
- } else {
- locations = new mutable.HashSet[BlockManagerId]
- blockLocations.put(blockId, locations)
- }
-
- if (storageLevel.isValid) {
- locations.add(blockManagerId)
- } else {
- locations.remove(blockManagerId)
- }
-
- // Remove the block from master tracking if it has been removed on all slaves.
- if (locations.size == 0) {
- blockLocations.remove(blockId)
- }
- true
- }
-
- private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
- if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
- }
-
- private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
- blockIds.map(blockId => getLocations(blockId))
- }
-
- /** Get the list of the peers of the given block manager */
- private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
- val blockManagerIds = blockManagerInfo.keySet
- if (blockManagerIds.contains(blockManagerId)) {
- blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq
- } else {
- Seq.empty
- }
- }
-
- /**
- * Returns the hostname and port of an executor's actor system, based on the Akka address of its
- * BlockManagerSlaveActor.
- */
- private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
- for (
- blockManagerId <- blockManagerIdByExecutor.get(executorId);
- info <- blockManagerInfo.get(blockManagerId);
- host <- info.slaveActor.path.address.host;
- port <- info.slaveActor.path.address.port
- ) yield {
- (host, port)
- }
- }
-}
-
-@DeveloperApi
-case class BlockStatus(
- storageLevel: StorageLevel,
- memSize: Long,
- diskSize: Long,
- tachyonSize: Long) {
- def isCached: Boolean = memSize + diskSize + tachyonSize > 0
-}
-
-@DeveloperApi
-object BlockStatus {
- def empty: BlockStatus = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)
-}
-
-private[spark] class BlockManagerInfo(
- val blockManagerId: BlockManagerId,
- timeMs: Long,
- val maxMem: Long,
- val slaveActor: ActorRef)
- extends Logging {
-
- private var _lastSeenMs: Long = timeMs
- private var _remainingMem: Long = maxMem
-
- // Mapping from block id to its status.
- private val _blocks = new JHashMap[BlockId, BlockStatus]
-
- def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId))
-
- def updateLastSeenMs() {
- _lastSeenMs = System.currentTimeMillis()
- }
-
- def updateBlockInfo(
- blockId: BlockId,
- storageLevel: StorageLevel,
- memSize: Long,
- diskSize: Long,
- tachyonSize: Long) {
-
- updateLastSeenMs()
-
- if (_blocks.containsKey(blockId)) {
- // The block exists on the slave already.
- val blockStatus: BlockStatus = _blocks.get(blockId)
- val originalLevel: StorageLevel = blockStatus.storageLevel
- val originalMemSize: Long = blockStatus.memSize
-
- if (originalLevel.useMemory) {
- _remainingMem += originalMemSize
- }
- }
-
- if (storageLevel.isValid) {
- /* isValid means it is either stored in-memory, on-disk or on-Tachyon.
- * The memSize here indicates the data size in or dropped from memory,
- * tachyonSize here indicates the data size in or dropped from Tachyon,
- * and the diskSize here indicates the data size in or dropped to disk.
- * They can be both larger than 0, when a block is dropped from memory to disk.
- * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */
- if (storageLevel.useMemory) {
- _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0))
- _remainingMem -= memSize
- logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
- Utils.bytesToString(_remainingMem)))
- }
- if (storageLevel.useDisk) {
- _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0))
- logInfo("Added %s on disk on %s (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
- }
- if (storageLevel.useOffHeap) {
- _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize))
- logInfo("Added %s on tachyon on %s (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize)))
- }
- } else if (_blocks.containsKey(blockId)) {
- // If isValid is not true, drop the block.
- val blockStatus: BlockStatus = _blocks.get(blockId)
- _blocks.remove(blockId)
- if (blockStatus.storageLevel.useMemory) {
- logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
- Utils.bytesToString(_remainingMem)))
- }
- if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s on disk (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
- }
- if (blockStatus.storageLevel.useOffHeap) {
- logInfo("Removed %s on %s on tachyon (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize)))
- }
- }
- }
-
- def removeBlock(blockId: BlockId) {
- if (_blocks.containsKey(blockId)) {
- _remainingMem += _blocks.get(blockId).memSize
- _blocks.remove(blockId)
- }
- }
-
- def remainingMem: Long = _remainingMem
-
- def lastSeenMs: Long = _lastSeenMs
-
- def blocks: JHashMap[BlockId, BlockStatus] = _blocks
-
- override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
-
- def clear() {
- _blocks.clear()
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
new file mode 100644
index 0000000..28c73a7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -0,0 +1,509 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable
+import scala.collection.JavaConversions._
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint}
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.BlockManagerMessages._
+import org.apache.spark.util.Utils
+
+/**
+ * BlockManagerMasterEndpoint is a [[ThreadSafeRpcEndpoint]] on the master node to track statuses
+ * of all slaves' block managers.
+ */
+private[spark]
+class BlockManagerMasterEndpoint(
+ override val rpcEnv: RpcEnv,
+ val isLocal: Boolean,
+ conf: SparkConf,
+ listenerBus: LiveListenerBus)
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ // Mapping from block manager id to the block manager's information.
+ private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
+
+ // Mapping from executor ID to block manager ID.
+ private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
+
+ // Mapping from block id to the set of block managers that have the block.
+ private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
+
+ private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
+ private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
+ register(blockManagerId, maxMemSize, slaveEndpoint)
+ context.reply(true)
+
+ case UpdateBlockInfo(
+ blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
+ context.reply(updateBlockInfo(
+ blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize))
+
+ case GetLocations(blockId) =>
+ context.reply(getLocations(blockId))
+
+ case GetLocationsMultipleBlockIds(blockIds) =>
+ context.reply(getLocationsMultipleBlockIds(blockIds))
+
+ case GetPeers(blockManagerId) =>
+ context.reply(getPeers(blockManagerId))
+
+ case GetRpcHostPortForExecutor(executorId) =>
+ context.reply(getRpcHostPortForExecutor(executorId))
+
+ case GetMemoryStatus =>
+ context.reply(memoryStatus)
+
+ case GetStorageStatus =>
+ context.reply(storageStatus)
+
+ case GetBlockStatus(blockId, askSlaves) =>
+ context.reply(blockStatus(blockId, askSlaves))
+
+ case GetMatchingBlockIds(filter, askSlaves) =>
+ context.reply(getMatchingBlockIds(filter, askSlaves))
+
+ case RemoveRdd(rddId) =>
+ context.reply(removeRdd(rddId))
+
+ case RemoveShuffle(shuffleId) =>
+ context.reply(removeShuffle(shuffleId))
+
+ case RemoveBroadcast(broadcastId, removeFromDriver) =>
+ context.reply(removeBroadcast(broadcastId, removeFromDriver))
+
+ case RemoveBlock(blockId) =>
+ removeBlockFromWorkers(blockId)
+ context.reply(true)
+
+ case RemoveExecutor(execId) =>
+ removeExecutor(execId)
+ context.reply(true)
+
+ case StopBlockManagerMaster =>
+ context.reply(true)
+ stop()
+
+ case BlockManagerHeartbeat(blockManagerId) =>
+ context.reply(heartbeatReceived(blockManagerId))
+
+ }
+
+ private def removeRdd(rddId: Int): Future[Seq[Int]] = {
+ // First remove the metadata for the given RDD, and then asynchronously remove the blocks
+ // from the slaves.
+
+ // Find all blocks for the given RDD, remove the block from both blockLocations and
+ // the blockManagerInfo that is tracking the blocks.
+ val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+ blocks.foreach { blockId =>
+ val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
+ bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
+ blockLocations.remove(blockId)
+ }
+
+ // Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
+ // The implicit askExecutionContext is used as the ExecutionContext for Future.sequence.
+ val removeMsg = RemoveRdd(rddId)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveEndpoint.sendWithReply[Int](removeMsg)
+ }.toSeq
+ )
+ }
+
+ private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
+ // Nothing to do in the BlockManagerMasterEndpoint data structures
+ val removeMsg = RemoveShuffle(shuffleId)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveEndpoint.sendWithReply[Boolean](removeMsg)
+ }.toSeq
+ )
+ }
+
+ /**
+ * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
+ * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
+ * from the executors, but not from the driver.
+ */
+ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
+ val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
+ val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+ removeFromDriver || !info.blockManagerId.isDriver
+ }
+ Future.sequence(
+ requiredBlockManagers.map { bm =>
+ bm.slaveEndpoint.sendWithReply[Int](removeMsg)
+ }.toSeq
+ )
+ }
+
+ private def removeBlockManager(blockManagerId: BlockManagerId) {
+ val info = blockManagerInfo(blockManagerId)
+
+ // Remove the block manager from blockManagerIdByExecutor.
+ blockManagerIdByExecutor -= blockManagerId.executorId
+
+ // Remove it from blockManagerInfo and remove all the blocks.
+ blockManagerInfo.remove(blockManagerId)
+ val iterator = info.blocks.keySet.iterator
+ while (iterator.hasNext) {
+ val blockId = iterator.next
+ val locations = blockLocations.get(blockId)
+ locations -= blockManagerId
+ if (locations.size == 0) {
+ blockLocations.remove(blockId)
+ }
+ }
+ listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
+ logInfo(s"Removing block manager $blockManagerId")
+ }
+
+ private def removeExecutor(execId: String) {
+ logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
+ blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
+ }
+
+ /**
+ * Return true if the driver knows about the given block manager. Otherwise, return false,
+ * indicating that the block manager should re-register.
+ */
+ private def heartbeatReceived(blockManagerId: BlockManagerId): Boolean = {
+ if (!blockManagerInfo.contains(blockManagerId)) {
+ blockManagerId.isDriver && !isLocal
+ } else {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ true
+ }
+ }
+
+ // Remove a block from the slaves that have it. This can only be used to remove
+ // blocks that the master knows about.
+ private def removeBlockFromWorkers(blockId: BlockId) {
+ val locations = blockLocations.get(blockId)
+ if (locations != null) {
+ locations.foreach { blockManagerId: BlockManagerId =>
+ val blockManager = blockManagerInfo.get(blockManagerId)
+ if (blockManager.isDefined) {
+ // Remove the block from the slave's BlockManager.
+ // Doesn't actually wait for a confirmation and the message might get lost.
+ // If message loss becomes frequent, we should add retry logic here.
+ blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId))
+ }
+ }
+ }
+ }
+
+ // Return a map from the block manager id to max memory and remaining memory.
+ private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
+ (blockManagerId, (info.maxMem, info.remainingMem))
+ }.toMap
+ }
+
+ private def storageStatus: Array[StorageStatus] = {
+ blockManagerInfo.map { case (blockManagerId, info) =>
+ new StorageStatus(blockManagerId, info.maxMem, info.blocks)
+ }.toArray
+ }
+
+ /**
+ * Return the block's status for all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def blockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
+ val getBlockStatus = GetBlockStatus(blockId)
+ /*
+ * Rather than blocking on the block status query, master endpoint should simply return
+ * Futures to avoid potential deadlocks. This can arise if there exists a block manager
+ * that is also waiting for this master endpoint's response to a previous message.
+ */
+ blockManagerInfo.values.map { info =>
+ val blockStatusFuture =
+ if (askSlaves) {
+ info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus)
+ } else {
+ Future { info.getStatus(blockId) }
+ }
+ (info.blockManagerId, blockStatusFuture)
+ }.toMap
+ }
+
+ /**
+ * Return the ids of blocks present in all the block managers that match the given filter.
+ * NOTE: This is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Future[Seq[BlockId]] = {
+ val getMatchingBlockIds = GetMatchingBlockIds(filter)
+ Future.sequence(
+ blockManagerInfo.values.map { info =>
+ val future =
+ if (askSlaves) {
+ info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds)
+ } else {
+ Future { info.blocks.keys.filter(filter).toSeq }
+ }
+ future
+ }
+ ).map(_.flatten.toSeq)
+ }
+
+ private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
+ val time = System.currentTimeMillis()
+ if (!blockManagerInfo.contains(id)) {
+ blockManagerIdByExecutor.get(id.executorId) match {
+ case Some(oldId) =>
+ // A block manager of the same executor already exists, so remove it (assumed dead)
+ logError("Got two different block manager registrations on same executor - "
+ + s" will replace old one $oldId with new one $id")
+ removeExecutor(id.executorId)
+ case None =>
+ }
+ logInfo("Registering block manager %s with %s RAM, %s".format(
+ id.hostPort, Utils.bytesToString(maxMemSize), id))
+
+ blockManagerIdByExecutor(id.executorId) = id
+
+ blockManagerInfo(id) = new BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
+ }
+ listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
+ }
+
+ private def updateBlockInfo(
+ blockManagerId: BlockManagerId,
+ blockId: BlockId,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ tachyonSize: Long): Boolean = {
+
+ if (!blockManagerInfo.contains(blockManagerId)) {
+ if (blockManagerId.isDriver && !isLocal) {
+ // We intentionally do not register the master (except in local mode),
+ // so we should not indicate failure.
+ return true
+ } else {
+ return false
+ }
+ }
+
+ if (blockId == null) {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ return true
+ }
+
+ blockManagerInfo(blockManagerId).updateBlockInfo(
+ blockId, storageLevel, memSize, diskSize, tachyonSize)
+
+ var locations: mutable.HashSet[BlockManagerId] = null
+ if (blockLocations.containsKey(blockId)) {
+ locations = blockLocations.get(blockId)
+ } else {
+ locations = new mutable.HashSet[BlockManagerId]
+ blockLocations.put(blockId, locations)
+ }
+
+ if (storageLevel.isValid) {
+ locations.add(blockManagerId)
+ } else {
+ locations.remove(blockManagerId)
+ }
+
+ // Remove the block from master tracking if it has been removed on all slaves.
+ if (locations.size == 0) {
+ blockLocations.remove(blockId)
+ }
+ true
+ }
+
+ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
+ if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
+ }
+
+ private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map(blockId => getLocations(blockId))
+ }
+
+ /** Get the list of the peers of the given block manager */
+ private def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
+ val blockManagerIds = blockManagerInfo.keySet
+ if (blockManagerIds.contains(blockManagerId)) {
+ blockManagerIds.filterNot { _.isDriver }.filterNot { _ == blockManagerId }.toSeq
+ } else {
+ Seq.empty
+ }
+ }
+
+ /**
+  * Looks up the [[RpcEnv]] address of the [[BlockManagerSlaveEndpoint]] registered for an
+  * executor.
+  *
+  * @return the endpoint's (host, port), or None when either the executor or its block
+  *         manager has no registration
+  */
+ private def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+   blockManagerIdByExecutor.get(executorId)
+     .flatMap(id => blockManagerInfo.get(id))
+     .map(info => (info.slaveEndpoint.address.host, info.slaveEndpoint.address.port))
+ }
+
+ /** Shuts down `askThreadPool` immediately when this endpoint is stopped. */
+ // NOTE(review): askThreadPool is declared outside this excerpt - presumably an
+ // ExecutorService, in which case its shutdownNow() result is deliberately discarded here.
+ override def onStop(): Unit = askThreadPool.shutdownNow()
+}
+
+/**
+ * The storage level of a block together with the size recorded for it in memory, on disk
+ * and on Tachyon.
+ */
+@DeveloperApi
+case class BlockStatus(
+    storageLevel: StorageLevel,
+    memSize: Long,
+    diskSize: Long,
+    tachyonSize: Long) {
+
+  /** True when the combined memory, disk and Tachyon size is positive. */
+  def isCached: Boolean = {
+    val totalSize = memSize + diskSize + tachyonSize
+    totalSize > 0
+  }
+}
+
+@DeveloperApi
+object BlockStatus {
+
+  /** A status with storage level NONE and every size set to zero. */
+  def empty: BlockStatus =
+    BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L, tachyonSize = 0L)
+}
+
+/**
+ * Master-side record of one block manager: when it was last heard from, how much of its
+ * memory budget is still unaccounted for, and the latest status of each block it reported.
+ *
+ * @param blockManagerId identity of the tracked block manager
+ * @param timeMs initial value of the last-seen timestamp (also printed by `toString`)
+ * @param maxMem initial value of the remaining-memory counter
+ * @param slaveEndpoint RPC reference stored for the tracked block manager
+ */
+private[spark] class BlockManagerInfo(
+    val blockManagerId: BlockManagerId,
+    timeMs: Long,
+    val maxMem: Long,
+    val slaveEndpoint: RpcEndpointRef)
+  extends Logging {
+
+  private var _lastSeenMs: Long = timeMs
+  private var _remainingMem: Long = maxMem
+
+  // Latest recorded status of every known block, keyed by block id.
+  private val _blocks = new JHashMap[BlockId, BlockStatus]
+
+  /** Returns the recorded status of `blockId`, or None when the block is unknown. */
+  def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId))
+
+  /** Stamps the last-seen time with the current wall-clock time. */
+  def updateLastSeenMs(): Unit = {
+    _lastSeenMs = System.currentTimeMillis()
+  }
+
+  /**
+   * Records a new status report for `blockId` and adjusts the remaining-memory counter.
+   *
+   * A valid `storageLevel` stores a fresh status; an invalid one drops the block if it was
+   * known. Either way the last-seen timestamp is refreshed first.
+   */
+  def updateBlockInfo(
+      blockId: BlockId,
+      storageLevel: StorageLevel,
+      memSize: Long,
+      diskSize: Long,
+      tachyonSize: Long): Unit = {
+
+    updateLastSeenMs()
+
+    // The slave already holds this block: hand back the memory charged by the old report
+    // before the new report is applied.
+    if (_blocks.containsKey(blockId)) {
+      val previous = _blocks.get(blockId)
+      if (previous.storageLevel.useMemory) {
+        _remainingMem += previous.memSize
+      }
+    }
+
+    if (storageLevel.isValid) {
+      // A valid level means the block is in memory, on disk and/or on Tachyon. Each size
+      // argument may describe data held in, or data just dropped from, its store, so several
+      // can be positive at once (e.g. a block dropped from memory to disk). Each put below
+      // therefore records only the size of its own store; when the level names several
+      // stores, the last put is the status that stays in the map.
+      if (storageLevel.useMemory) {
+        _blocks.put(blockId, BlockStatus(storageLevel, memSize = memSize, diskSize = 0,
+          tachyonSize = 0))
+        _remainingMem -= memSize
+        logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
+          blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
+          Utils.bytesToString(_remainingMem)))
+      }
+      if (storageLevel.useDisk) {
+        _blocks.put(blockId, BlockStatus(storageLevel, memSize = 0, diskSize = diskSize,
+          tachyonSize = 0))
+        logInfo("Added %s on disk on %s (size: %s)".format(
+          blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
+      }
+      if (storageLevel.useOffHeap) {
+        _blocks.put(blockId, BlockStatus(storageLevel, memSize = 0, diskSize = 0,
+          tachyonSize = tachyonSize))
+        logInfo("Added %s on tachyon on %s (size: %s)".format(
+          blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize)))
+      }
+    } else if (_blocks.containsKey(blockId)) {
+      // An invalid level means the block is gone: forget it and log where it was stored.
+      val dropped = _blocks.remove(blockId)
+      val droppedLevel = dropped.storageLevel
+      if (droppedLevel.useMemory) {
+        logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
+          blockId, blockManagerId.hostPort, Utils.bytesToString(dropped.memSize),
+          Utils.bytesToString(_remainingMem)))
+      }
+      if (droppedLevel.useDisk) {
+        logInfo("Removed %s on %s on disk (size: %s)".format(
+          blockId, blockManagerId.hostPort, Utils.bytesToString(dropped.diskSize)))
+      }
+      if (droppedLevel.useOffHeap) {
+        logInfo("Removed %s on %s on tachyon (size: %s)".format(
+          blockId, blockManagerId.hostPort, Utils.bytesToString(dropped.tachyonSize)))
+      }
+    }
+  }
+
+  /** Forgets `blockId` if it is known and adds its recorded memory size back to the counter. */
+  def removeBlock(blockId: BlockId): Unit = {
+    if (_blocks.containsKey(blockId)) {
+      val removed = _blocks.remove(blockId)
+      _remainingMem += removed.memSize
+    }
+  }
+
+  def remainingMem: Long = _remainingMem
+
+  def lastSeenMs: Long = _lastSeenMs
+
+  /** Exposes the underlying (mutable) block-status map itself, not a copy. */
+  def blocks: JHashMap[BlockId, BlockStatus] = _blocks
+
+  override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
+
+  /** Drops every recorded block status; the remaining-memory counter is left untouched. */
+  def clear(): Unit = {
+    _blocks.clear()
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 4824745..f89d8d7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -19,8 +19,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, ObjectInput, ObjectOutput}
-import akka.actor.ActorRef
-
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.Utils
private[spark] object BlockManagerMessages {
@@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages {
case class RegisterBlockManager(
blockManagerId: BlockManagerId,
maxMemSize: Long,
- sender: ActorRef)
+ sender: RpcEndpointRef)
extends ToBlockManagerMaster
case class UpdateBlockInfo(
@@ -92,7 +91,7 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
- case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+ case class GetRpcHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
deleted file mode 100644
index 52fb896..0000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import scala.concurrent.Future
-
-import akka.actor.{ActorRef, Actor}
-
-import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
-import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.ActorLogReceive
-
-/**
- * An actor to take commands from the master to execute options. For example,
- * this is used to remove blocks from the slave's BlockManager.
- */
-private[storage]
-class BlockManagerSlaveActor(
- blockManager: BlockManager,
- mapOutputTracker: MapOutputTracker)
- extends Actor with ActorLogReceive with Logging {
-
- import context.dispatcher
-
- // Operations that involve removing blocks may be slow and should be done asynchronously
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RemoveBlock(blockId) =>
- doAsync[Boolean]("removing block " + blockId, sender) {
- blockManager.removeBlock(blockId)
- true
- }
-
- case RemoveRdd(rddId) =>
- doAsync[Int]("removing RDD " + rddId, sender) {
- blockManager.removeRdd(rddId)
- }
-
- case RemoveShuffle(shuffleId) =>
- doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
- if (mapOutputTracker != null) {
- mapOutputTracker.unregisterShuffle(shuffleId)
- }
- SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
- }
-
- case RemoveBroadcast(broadcastId, _) =>
- doAsync[Int]("removing broadcast " + broadcastId, sender) {
- blockManager.removeBroadcast(broadcastId, tellMaster = true)
- }
-
- case GetBlockStatus(blockId, _) =>
- sender ! blockManager.getStatus(blockId)
-
- case GetMatchingBlockIds(filter, _) =>
- sender ! blockManager.getMatchingBlockIds(filter)
- }
-
- private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
- val future = Future {
- logDebug(actionMessage)
- body
- }
- future.onSuccess { case response =>
- logDebug("Done " + actionMessage + ", response is " + response)
- responseActor ! response
- logDebug("Sent response: " + response + " to " + responseActor)
- }
- future.onFailure { case t: Throwable =>
- logError("Error in " + actionMessage, t)
- responseActor ! null.asInstanceOf[T]
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
new file mode 100644
index 0000000..8980fa8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.concurrent.{ExecutionContext, Future}
+
+import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
+import org.apache.spark.storage.BlockManagerMessages._
+
+/**
+ * An RpcEndpoint that carries out commands sent by the master, for example removing blocks
+ * from the slave's BlockManager.
+ */
+private[storage]
+class BlockManagerSlaveEndpoint(
+    override val rpcEnv: RpcEnv,
+    blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker)
+  extends RpcEndpoint with Logging {
+
+  private val asyncThreadPool =
+    Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool")
+  private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool)
+
+  /**
+   * Dispatches one command. Removals may be slow, so they run through `doAsync`;
+   * the two status queries are answered inline.
+   */
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case RemoveBlock(blockId) =>
+      doAsync[Boolean]("removing block " + blockId, context) {
+        blockManager.removeBlock(blockId)
+        true
+      }
+
+    case RemoveRdd(rddId) =>
+      doAsync[Int]("removing RDD " + rddId, context) {
+        blockManager.removeRdd(rddId)
+      }
+
+    case RemoveShuffle(shuffleId) =>
+      doAsync[Boolean]("removing shuffle " + shuffleId, context) {
+        if (mapOutputTracker != null) {
+          mapOutputTracker.unregisterShuffle(shuffleId)
+        }
+        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
+      }
+
+    case RemoveBroadcast(broadcastId, _) =>
+      doAsync[Int]("removing broadcast " + broadcastId, context) {
+        blockManager.removeBroadcast(broadcastId, tellMaster = true)
+      }
+
+    case GetBlockStatus(blockId, _) =>
+      context.reply(blockManager.getStatus(blockId))
+
+    case GetMatchingBlockIds(filter, _) =>
+      context.reply(blockManager.getMatchingBlockIds(filter))
+  }
+
+  /**
+   * Evaluates `body` in a Future on `asyncExecutionContext`, then replies to `context` with
+   * the result, or reports the failure to `context` if `body` throws.
+   *
+   * @param actionMessage description of the action, used only in log messages
+   */
+  private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T): Unit = {
+    val pending = Future {
+      logDebug(actionMessage)
+      body
+    }
+    pending.onComplete {
+      case scala.util.Success(response) =>
+        logDebug("Done " + actionMessage + ", response is " + response)
+        context.reply(response)
+        logDebug("Sent response: " + response + " to " + context.sender)
+
+      case scala.util.Failure(t) =>
+        logError("Error in " + actionMessage, t)
+        context.sendFailure(t)
+    }
+  }
+
+  /** Shuts down the async pool immediately when the endpoint stops. */
+  override def onStop(): Unit = {
+    asyncThreadPool.shutdownNow()
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 7c85e28..0fdfaf3 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1214,6 +1214,16 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+  * Runs `block` and logs any non-fatal error it throws instead of propagating it.
+  * Throwables that `NonFatal` does not match are not caught and reach the caller.
+  */
+ def tryLogNonFatalError(block: => Unit): Unit = {
+   try block catch {
+     case NonFatal(cause) =>
+       logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", cause)
+   }
+ }
+
/**
* Execute a block of code, then a finally block, but if exceptions happen in
* the finally block, do not suppress the original exception.
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
new file mode 100644
index 0000000..0fd570e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
+import org.scalatest.FunSuite
+import org.mockito.Mockito.{mock, spy, verify, when}
+import org.mockito.Matchers
+import org.mockito.Matchers._
+
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.util.RpcUtils
+import org.scalatest.concurrent.Eventually._
+
+/**
+ * Checks that [[HeartbeatReceiver]] forwards an executor heartbeat to the task scheduler and
+ * turns the scheduler's reply into the `reregisterBlockManager` flag of the response.
+ */
+class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext {
+
+  test("HeartbeatReceiver") {
+    val response = sendHeartbeat(schedulerReply = true)
+    assert(false === response.reregisterBlockManager)
+  }
+
+  test("HeartbeatReceiver re-register") {
+    val response = sendHeartbeat(schedulerReply = false)
+    assert(true === response.reregisterBlockManager)
+  }
+
+  /**
+   * Starts a local SparkContext whose task scheduler is a mock, registers a
+   * HeartbeatReceiver, sends it one heartbeat for "executor-1" and verifies that the
+   * heartbeat reached the scheduler.
+   *
+   * @param schedulerReply value the mocked `executorHeartbeatReceived` returns
+   * @return the response the receiver sent back for the heartbeat
+   */
+  private def sendHeartbeat(schedulerReply: Boolean): HeartbeatResponse = {
+    sc = spy(new SparkContext("local[2]", "test"))
+    val scheduler = mock(classOf[TaskScheduler])
+    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(schedulerReply)
+    when(sc.taskScheduler).thenReturn(scheduler)
+
+    val heartbeatReceiver = new HeartbeatReceiver(sc)
+    sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
+    // TaskSchedulerIsSet is delivered asynchronously; wait until the receiver has seen it.
+    eventually(timeout(5 seconds), interval(5 millis)) {
+      assert(heartbeatReceiver.scheduler != null)
+    }
+    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+
+    val metrics = new TaskMetrics
+    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
+    val response = receiverRef.askWithReply[HeartbeatResponse](
+      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+
+    verify(scheduler).executorHeartbeatReceived(
+      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+    response
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index e07bdb9..4f19c4f 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -311,7 +311,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
}
test("self: call in onStop") {
- @volatile var e: Throwable = null
+ @volatile var selfOption: Option[RpcEndpointRef] = null
val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint {
override val rpcEnv = env
@@ -321,20 +321,18 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
}
override def onStop(): Unit = {
- self
+ selfOption = Option(self)
}
override def onError(cause: Throwable): Unit = {
- e = cause
}
})
env.stop(endpointRef)
eventually(timeout(5 seconds), interval(10 millis)) {
- // Calling `self` in `onStop` is invalid
- assert(e != null)
- assert(e.getMessage.contains("Cannot find RpcEndpointRef"))
+ // Calling `self` in `onStop` will return null, so selfOption will be None
+ assert(selfOption == None)
}
}
@@ -342,7 +340,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
// If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it
for(i <- 0 until 100) {
@volatile var result = 0
- val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint {
+ val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new ThreadSafeRpcEndpoint {
override val rpcEnv = env
override def receive = {
@@ -475,7 +473,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll {
test("network events") {
val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
- env.setupThreadSafeEndpoint("network-events", new RpcEndpoint {
+ env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint {
override val rpcEnv = env
override def receive = {
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index c2903c8..b4de90b 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -22,11 +22,11 @@ import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.language.postfixOps
-import akka.actor.{ActorSystem, Props}
import org.mockito.Mockito.{mock, when}
-import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
import org.scalatest.concurrent.Eventually._
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager}
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
@@ -34,13 +34,12 @@ import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.storage.StorageLevel._
-import org.apache.spark.util.{AkkaUtils, SizeEstimator}
/** Testsuite that tests block replication in BlockManager */
class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter {
private val conf = new SparkConf(false)
- var actorSystem: ActorSystem = null
+ var rpcEnv: RpcEnv = null
var master: BlockManagerMaster = null
val securityMgr = new SecurityManager(conf)
val mapOutputTracker = new MapOutputTrackerMaster(conf)
@@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
store.initialize("app-id")
allStores += store
@@ -69,12 +68,10 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
}
before {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "test", "localhost", 0, conf = conf, securityManager = securityMgr)
- this.actorSystem = actorSystem
+ rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
conf.set("spark.authenticate", "false")
- conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.driver.port", rpcEnv.address.port.toString)
conf.set("spark.storage.unrollFraction", "0.4")
conf.set("spark.storage.unrollMemoryThreshold", "512")
@@ -83,18 +80,17 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
// to make cached peers refresh frequently
conf.set("spark.storage.cachedPeersTtl", "10")
- master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf, true)
+ master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
+ new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
allStores.clear()
}
after {
allStores.foreach { _.stop() }
allStores.clear()
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- actorSystem = null
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
+ rpcEnv = null
master = null
}
@@ -262,7 +258,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work
when(failableTransfer.hostName).thenReturn("some-hostname")
when(failableTransfer.port).thenReturn(1000)
- val failableStore = new BlockManager("failable-store", actorSystem, master, serializer,
+ val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer,
10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
failableStore.initialize("app-id")
allStores += failableStore // so that this gets stopped after test
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index ecd1cba..283090e 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -19,24 +19,18 @@ package org.apache.spark.storage
import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
-import java.util.concurrent.TimeUnit
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.language.postfixOps
-import akka.actor._
-import akka.pattern.ask
-import akka.util.Timeout
-
import org.mockito.Mockito.{mock, when}
-
import org.scalatest._
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager}
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.network.nio.NioBlockTransferService
@@ -53,7 +47,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
private val conf = new SparkConf(false)
var store: BlockManager = null
var store2: BlockManager = null
- var actorSystem: ActorSystem = null
+ var rpcEnv: RpcEnv = null
var master: BlockManagerMaster = null
conf.set("spark.authenticate", "false")
val securityMgr = new SecurityManager(conf)
@@ -72,28 +66,25 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
- val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
+ val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
manager.initialize("app-id")
manager
}
override def beforeEach(): Unit = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "test", "localhost", 0, conf = conf, securityManager = securityMgr)
- this.actorSystem = actorSystem
+ rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
System.setProperty("os.arch", "amd64")
conf.set("os.arch", "amd64")
conf.set("spark.test.useCompressedOops", "true")
- conf.set("spark.driver.port", boundPort.toString)
+ conf.set("spark.driver.port", rpcEnv.address.port.toString)
conf.set("spark.storage.unrollFraction", "0.4")
conf.set("spark.storage.unrollMemoryThreshold", "512")
- master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf, true)
+ master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
+ new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
@@ -108,9 +99,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
store2.stop()
store2 = null
}
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- actorSystem = null
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
+ rpcEnv = null
master = null
}
@@ -357,10 +348,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
master.removeExecutor(store.blockManagerId.executorId)
assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- implicit val timeout = Timeout(30, TimeUnit.SECONDS)
- val reregister = !Await.result(
- master.driverActor ? BlockManagerHeartbeat(store.blockManagerId),
- timeout.duration).asInstanceOf[Boolean]
+ val reregister = !master.driverEndpoint.askWithReply[Boolean](
+ BlockManagerHeartbeat(store.blockManagerId))
assert(reregister == true)
}
@@ -785,7 +774,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
val transfer = new NioBlockTransferService(conf, securityMgr)
- store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master,
+ store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr,
0)
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 18a477f..ef4873d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -24,20 +24,20 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.language.postfixOps
-import akka.actor.{ActorSystem, Props}
import org.apache.hadoop.conf.Configuration
import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.storage._
import org.apache.spark.streaming.receiver._
import org.apache.spark.streaming.util._
-import org.apache.spark.util.{AkkaUtils, ManualClock, Utils}
+import org.apache.spark.util.{ManualClock, Utils}
import WriteAheadLogBasedBlockHandler._
import WriteAheadLogSuite._
@@ -54,22 +54,19 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
val manualClock = new ManualClock
val blockManagerSize = 10000000
- var actorSystem: ActorSystem = null
+ var rpcEnv: RpcEnv = null
var blockManagerMaster: BlockManagerMaster = null
var blockManager: BlockManager = null
var tempDirectory: File = null
before {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "test", "localhost", 0, conf = conf, securityManager = securityMgr)
- this.actorSystem = actorSystem
- conf.set("spark.driver.port", boundPort.toString)
+ rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr)
+ conf.set("spark.driver.port", rpcEnv.address.port.toString)
- blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))),
- conf, true)
+ blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager",
+ new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true)
- blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer,
+ blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer,
blockManagerSize, conf, mapOutputTracker, shuffleManager,
new NioBlockTransferService(conf, securityMgr), securityMgr, 0)
blockManager.initialize("app-id")
@@ -87,9 +84,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
blockManagerMaster.stop()
blockManagerMaster = null
}
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- actorSystem = null
+ rpcEnv.shutdown()
+ rpcEnv.awaitTermination()
+ rpcEnv = null
Utils.deleteRecursively(tempDirectory)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
----------------------------------------------------------------------
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 455554e..24a1e02 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -24,22 +24,20 @@ import java.lang.reflect.InvocationTargetException
import java.net.{Socket, URL}
import java.util.concurrent.atomic.AtomicReference
-import akka.actor._
-import akka.remote._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.SparkException
import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.scheduler.cluster.YarnSchedulerBackend
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader,
- SignalLogger, Utils}
+import org.apache.spark.util._
/**
* Common application master functionality for Spark on Yarn.
@@ -72,8 +70,8 @@ private[spark] class ApplicationMaster(
@volatile private var allocator: YarnAllocator = _
// Fields used in client mode.
- private var actorSystem: ActorSystem = null
- private var actor: ActorRef = _
+ private var rpcEnv: RpcEnv = null
+ private var amEndpoint: RpcEndpointRef = _
// Fields used in cluster mode.
private val sparkContextRef = new AtomicReference[SparkContext](null)
@@ -240,22 +238,21 @@ private[spark] class ApplicationMaster(
}
/**
- * Create an actor that communicates with the driver.
+ * Create an [[RpcEndpoint]] that communicates with the driver.
*
* In cluster mode, the AM and the driver belong to same process
- * so the AM actor need not monitor lifecycle of the driver.
+ * so the AMEndpoint need not monitor lifecycle of the driver.
*/
- private def runAMActor(
+ private def runAMEndpoint(
host: String,
port: String,
isClusterMode: Boolean): Unit = {
- val driverUrl = AkkaUtils.address(
- AkkaUtils.protocol(actorSystem),
+ val driverEndpont = rpcEnv.setupEndpointRef(
SparkEnv.driverActorSystemName,
- host,
- port,
- YarnSchedulerBackend.ACTOR_NAME)
- actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isClusterMode)), name = "YarnAM")
+ RpcAddress(host, port.toInt),
+ YarnSchedulerBackend.ENDPOINT_NAME)
+ amEndpoint =
+ rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode))
}
private def runDriver(securityMgr: SecurityManager): Unit = {
@@ -272,8 +269,8 @@ private[spark] class ApplicationMaster(
ApplicationMaster.EXIT_SC_NOT_INITED,
"Timed out waiting for SparkContext.")
} else {
- actorSystem = sc.env.actorSystem
- runAMActor(
+ rpcEnv = sc.env.rpcEnv
+ runAMEndpoint(
sc.getConf.get("spark.driver.host"),
sc.getConf.get("spark.driver.port"),
isClusterMode = true)
@@ -283,8 +280,7 @@ private[spark] class ApplicationMaster(
}
private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
- actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
- conf = sparkConf, securityManager = securityMgr)._1
+ rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, 0, sparkConf, securityMgr)
waitForSparkDriver()
addAmIpFilter()
registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr)
@@ -431,7 +427,7 @@ private[spark] class ApplicationMaster(
sparkConf.set("spark.driver.host", driverHost)
sparkConf.set("spark.driver.port", driverPort.toString)
- runAMActor(driverHost, driverPort.toString, isClusterMode = false)
+ runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false)
}
/** Add the Yarn IP filter that is required for properly securing the UI. */
@@ -443,7 +439,7 @@ private[spark] class ApplicationMaster(
System.setProperty("spark.ui.filters", amFilter)
params.foreach { case (k, v) => System.setProperty(s"spark.$amFilter.param.$k", v) }
} else {
- actor ! AddWebUIFilter(amFilter, params.toMap, proxyBase)
+ amEndpoint.send(AddWebUIFilter(amFilter, params.toMap, proxyBase))
}
}
@@ -505,44 +501,29 @@ private[spark] class ApplicationMaster(
}
/**
- * An actor that communicates with the driver's scheduler backend.
+ * An [[RpcEndpoint]] that communicates with the driver's scheduler backend.
*/
- private class AMActor(driverUrl: String, isClusterMode: Boolean) extends Actor {
- var driver: ActorSelection = _
-
- override def preStart(): Unit = {
- logInfo("Listen to driver: " + driverUrl)
- driver = context.actorSelection(driverUrl)
- // Send a hello message to establish the connection, after which
- // we can monitor Lifecycle Events.
- driver ! "Hello"
- driver ! RegisterClusterManager
- // In cluster mode, the AM can directly monitor the driver status instead
- // of trying to deduce it from the lifecycle of the driver's actor
- if (!isClusterMode) {
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- }
+ private class AMEndpoint(
+ override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean)
+ extends RpcEndpoint with Logging {
+
+ override def onStart(): Unit = {
+ driver.send(RegisterClusterManager(self))
}
override def receive: PartialFunction[Any, Unit] = {
- case x: DisassociatedEvent =>
- logInfo(s"Driver terminated or disconnected! Shutting down. $x")
- // In cluster mode, do not rely on the disassociated event to exit
- // This avoids potentially reporting incorrect exit codes if the driver fails
- if (!isClusterMode) {
- finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
- }
-
case x: AddWebUIFilter =>
logInfo(s"Add WebUI Filter. $x")
- driver ! x
+ driver.send(x)
+ }
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RequestExecutors(requestedTotal) =>
Option(allocator) match {
case Some(a) => a.requestTotalExecutors(requestedTotal)
case None => logWarning("Container allocator is not ready to request executors yet.")
}
- sender ! true
+ context.reply(true)
case KillExecutors(executorIds) =>
logInfo(s"Driver requested to kill executor(s) ${executorIds.mkString(", ")}.")
@@ -550,7 +531,16 @@ private[spark] class ApplicationMaster(
case Some(a) => executorIds.foreach(a.killExecutor)
case None => logWarning("Container allocator is not ready to kill executors yet.")
}
- sender ! true
+ context.reply(true)
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress")
+ // In cluster mode, do not rely on the disassociated event to exit
+ // This avoids potentially reporting incorrect exit codes if the driver fails
+ if (!isClusterMode) {
+ finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS)
+ }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
----------------------------------------------------------------------
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index c98763e..b8f42da 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -112,7 +112,7 @@ private[yarn] class YarnAllocator(
SparkEnv.driverActorSystemName,
sparkConf.get("spark.driver.host"),
sparkConf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
// For testing
private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org
[2/2] spark git commit: [SPARK-6602][Core] Replace direct use of Akka
with Spark RPC interface - part 1
Posted by rx...@apache.org.
[SPARK-6602][Core] Replace direct use of Akka with Spark RPC interface - part 1
This PR replaced the following `Actor`s with `RpcEndpoint`s:
1. HeartbeatReceiver
1. ExecutorActor
1. BlockManagerMasterActor
1. BlockManagerSlaveActor
1. CoarseGrainedExecutorBackend and subclasses
1. CoarseGrainedSchedulerBackend.DriverActor
This is the first PR. I will split the work of SPARK-6602 into several PRs for code review.
Author: zsxwing <zs...@gmail.com>
Closes #5268 from zsxwing/rpc-rewrite and squashes the following commits:
287e9f8 [zsxwing] Fix the code style
26c56b7 [zsxwing] Merge branch 'master' into rpc-rewrite
9cc825a [zsxwing] Remove setupThreadSafeEndpoint and add ThreadSafeRpcEndpoint
30a9036 [zsxwing] Make self return null after stopping RpcEndpointRef; fix docs and error messages
705245d [zsxwing] Fix some bugs after rebasing the changes on the master
003cf80 [zsxwing] Update CoarseGrainedExecutorBackend and CoarseGrainedSchedulerBackend to use RpcEndpoint
7d0e6dc [zsxwing] Update BlockManagerSlaveActor to use RpcEndpoint
f5d6543 [zsxwing] Update BlockManagerMaster to use RpcEndpoint
30e3f9f [zsxwing] Update ExecutorActor to use RpcEndpoint
478b443 [zsxwing] Update HeartbeatReceiver to use RpcEndpoint
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f15806a8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f15806a8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f15806a8
Branch: refs/heads/master
Commit: f15806a8f8ca34288ddb2d74b9ff1972c8374b59
Parents: 7bca62f
Author: zsxwing <zs...@gmail.com>
Authored: Sat Apr 4 11:52:05 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Apr 4 11:52:05 2015 -0700
----------------------------------------------------------------------
.../org/apache/spark/HeartbeatReceiver.scala | 66 ++-
.../scala/org/apache/spark/SparkContext.scala | 23 +-
.../main/scala/org/apache/spark/SparkEnv.scala | 13 +-
.../executor/CoarseGrainedExecutorBackend.scala | 79 +--
.../org/apache/spark/executor/Executor.scala | 18 +-
.../apache/spark/executor/ExecutorActor.scala | 41 --
.../spark/executor/ExecutorEndpoint.scala | 43 ++
.../scala/org/apache/spark/rpc/RpcEnv.scala | 39 +-
.../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 10 +-
.../apache/spark/scheduler/DAGScheduler.scala | 11 +-
.../cluster/CoarseGrainedClusterMessage.scala | 6 +-
.../cluster/CoarseGrainedSchedulerBackend.scala | 148 +++---
.../spark/scheduler/cluster/ExecutorData.scala | 8 +-
.../cluster/SimrSchedulerBackend.scala | 13 +-
.../cluster/SparkDeploySchedulerBackend.scala | 14 +-
.../cluster/YarnSchedulerBackend.scala | 93 ++--
.../mesos/CoarseMesosSchedulerBackend.scala | 4 +-
.../spark/scheduler/local/LocalBackend.scala | 48 +-
.../org/apache/spark/storage/BlockManager.scala | 22 +-
.../spark/storage/BlockManagerMaster.scala | 72 ++-
.../spark/storage/BlockManagerMasterActor.scala | 512 -------------------
.../storage/BlockManagerMasterEndpoint.scala | 509 ++++++++++++++++++
.../spark/storage/BlockManagerMessages.scala | 7 +-
.../spark/storage/BlockManagerSlaveActor.scala | 88 ----
.../storage/BlockManagerSlaveEndpoint.scala | 94 ++++
.../scala/org/apache/spark/util/Utils.scala | 10 +
.../apache/spark/HeartbeatReceiverSuite.scala | 81 +++
.../org/apache/spark/rpc/RpcEnvSuite.scala | 14 +-
.../storage/BlockManagerReplicationSuite.scala | 28 +-
.../spark/storage/BlockManagerSuite.scala | 37 +-
.../streaming/ReceivedBlockHandlerSuite.scala | 25 +-
.../spark/deploy/yarn/ApplicationMaster.scala | 86 ++--
.../spark/deploy/yarn/YarnAllocator.scala | 2 +-
33 files changed, 1169 insertions(+), 1095 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 9f8ad03..5871b8c 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -17,15 +17,15 @@
package org.apache.spark
-import scala.concurrent.duration._
-import scala.collection.mutable
+import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors}
-import akka.actor.{Actor, Cancellable}
+import scala.collection.mutable
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.ActorLogReceive
+import org.apache.spark.util.Utils
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -51,9 +51,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
* Lives in the driver to receive heartbeats from executors..
*/
private[spark] class HeartbeatReceiver(sc: SparkContext)
- extends Actor with ActorLogReceive with Logging {
+ extends ThreadSafeRpcEndpoint with Logging {
+
+ override val rpcEnv: RpcEnv = sc.env.rpcEnv
- private var scheduler: TaskScheduler = null
+ private[spark] var scheduler: TaskScheduler = null
// executor ID -> timestamp of when the last heartbeat from this executor was received
private val executorLastSeen = new mutable.HashMap[String, Long]
@@ -69,34 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000).
getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000))
- private var timeoutCheckingTask: Cancellable = null
-
- override def preStart(): Unit = {
- import context.dispatcher
- timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
- checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
- super.preStart()
+ private var timeoutCheckingTask: ScheduledFuture[_] = null
+
+ private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor(
+ Utils.namedThreadFactory("heartbeat-timeout-checking-thread"))
+
+ private val killExecutorThread = Executors.newSingleThreadExecutor(
+ Utils.namedThreadFactory("kill-executor-thread"))
+
+ override def onStart(): Unit = {
+ timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(ExpireDeadHosts))
+ }
+ }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}
-
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case ExpireDeadHosts =>
+ expireDeadHosts()
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
if (scheduler != null) {
val unknownExecutor = !scheduler.executorHeartbeatReceived(
executorId, taskMetrics, blockManagerId)
val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
executorLastSeen(executorId) = System.currentTimeMillis()
- sender ! response
+ context.reply(response)
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
// register itself again.
logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet")
- sender ! HeartbeatResponse(reregisterBlockManager = true)
+ context.reply(HeartbeatResponse(reregisterBlockManager = true))
}
- case ExpireDeadHosts =>
- expireDeadHosts()
}
private def expireDeadHosts(): Unit = {
@@ -109,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
s"timed out after ${now - lastSeenMs} ms"))
if (sc.supportDynamicAllocation) {
- sc.killExecutor(executorId)
+ // Asynchronously kill the executor to avoid blocking the current thread
+ killExecutorThread.submit(new Runnable {
+ override def run(): Unit = sc.killExecutor(executorId)
+ })
}
executorLastSeen.remove(executorId)
}
}
}
- override def postStop(): Unit = {
+ override def onStop(): Unit = {
if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel()
+ timeoutCheckingTask.cancel(true)
}
- super.postStop()
+ timeoutCheckingThread.shutdownNow()
+ killExecutorThread.shutdownNow()
}
}
+
+object HeartbeatReceiver {
+ val ENDPOINT_NAME = "HeartbeatReceiver"
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3b73a8a..942c597 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -32,8 +32,6 @@ import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
-import akka.actor.Props
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
@@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump}
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat,
FixedLengthBinaryInputFormat}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
+import org.apache.spark.rpc.RpcAddress
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
SparkDeploySchedulerBackend, SimrSchedulerBackend}
@@ -360,14 +359,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
// retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640)
- private val heartbeatReceiver = env.actorSystem.actorOf(
- Props(new HeartbeatReceiver(this)), "HeartbeatReceiver")
+ private val heartbeatReceiver = env.rpcEnv.setupEndpoint(
+ HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this))
// Create and start the scheduler
private[spark] var (schedulerBackend, taskScheduler) =
SparkContext.createTaskScheduler(this, master)
- heartbeatReceiver ! TaskSchedulerIsSet
+ heartbeatReceiver.send(TaskSchedulerIsSet)
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
@@ -455,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
if (executorId == SparkContext.DRIVER_IDENTIFIER) {
Some(Utils.getThreadDump())
} else {
- val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
- val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
- Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
- AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get
+ val endpointRef = env.rpcEnv.setupEndpointRef(
+ SparkEnv.executorActorSystemName,
+ RpcAddress(host, port),
+ ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME)
+ Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump))
}
} catch {
case e: Exception =>
@@ -1418,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
dagScheduler = null
listenerBus.stop()
eventLogger.foreach(_.stop())
- env.actorSystem.stop(heartbeatReceiver)
+ env.rpcEnv.stop(heartbeatReceiver)
progressBar.foreach(_.stop())
taskScheduler = null
// TODO: Cache.stop()?
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4a2ed82..55be0a5 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -295,7 +295,9 @@ object SparkEnv extends Logging {
}
}
- def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
+ def registerOrLookupEndpoint(
+ name: String, endpointCreator: => RpcEndpoint):
+ RpcEndpointRef = {
if (isDriver) {
logInfo("Registering " + name)
rpcEnv.setupEndpoint(name, endpointCreator)
@@ -334,12 +336,13 @@ object SparkEnv extends Logging {
new NioBlockTransferService(conf, securityManager)
}
- val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
- "BlockManagerMaster",
- new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
+ BlockManagerMaster.DRIVER_ENDPOINT_NAME,
+ new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
+ conf, isDriver)
// NB: blockManager is not valid until initialize() is called later.
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+ val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
numUsableCores)
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 900e678..8300f9f 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -21,39 +21,45 @@ import java.net.URL
import java.nio.ByteBuffer
import scala.collection.mutable
-import scala.concurrent.Await
+import scala.util.{Failure, Success}
-import akka.actor.{Actor, ActorSelection, Props}
-import akka.pattern.Patterns
-import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.rpc._
+import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{SignalLogger, Utils}
private[spark] class CoarseGrainedExecutorBackend(
+ override val rpcEnv: RpcEnv,
driverUrl: String,
executorId: String,
hostPort: String,
cores: Int,
userClassPath: Seq[URL],
env: SparkEnv)
- extends Actor with ActorLogReceive with ExecutorBackend with Logging {
+ extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
- var driver: ActorSelection = null
+ @volatile var driver: Option[RpcEndpointRef] = None
- override def preStart() {
+ override def onStart() {
+ import scala.concurrent.ExecutionContext.Implicits.global
logInfo("Connecting to driver: " + driverUrl)
- driver = context.actorSelection(driverUrl)
- driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls)
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
+ driver = Some(ref)
+ ref.sendWithReply[RegisteredExecutor.type](
+ RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
+ } onComplete {
+ case Success(msg) => Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor
+ }
+ case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e)
+ }
}
def extractLogUrls: Map[String, String] = {
@@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend(
.map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
}
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
val (hostname, _) = Utils.parseHostPort(hostPort)
@@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.killTask(taskId, interruptThread)
}
- case x: DisassociatedEvent =>
- if (x.remoteAddress == driver.anchorPath.address) {
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
- } else {
- logWarning(s"Received irrelevant DisassociatedEvent $x")
- }
-
case StopExecutor =>
logInfo("Driver commanded a shutdown")
executor.stop()
- context.stop(self)
- context.system.shutdown()
+ stop()
+ rpcEnv.shutdown()
+ }
+
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (driver.exists(_.address == remoteAddress)) {
+ logError(s"Driver $remoteAddress disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"An unknown ($remoteAddress) driver disconnected.")
+ }
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
- driver ! StatusUpdate(executorId, taskId, state, data)
+ val msg = StatusUpdate(executorId, taskId, state, data)
+ driver match {
+ case Some(driverRef) => driverRef.send(msg)
+ case None => logWarning(s"Drop $msg because has not yet connected to driver")
+ }
}
}
@@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Bootstrap to fetch the driver's Spark properties.
val executorConf = new SparkConf
val port = executorConf.getInt("spark.executor.port", 0)
- val (fetcher, _) = AkkaUtils.createActorSystem(
+ val fetcher = RpcEnv.create(
"driverPropsFetcher",
hostname,
port,
executorConf,
new SecurityManager(executorConf))
- val driver = fetcher.actorSelection(driverUrl)
- val timeout = AkkaUtils.askTimeout(executorConf)
- val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
- val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++
+ val driver = fetcher.setupEndpointRefByURI(driverUrl)
+ val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++
Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
@@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val boundPort = env.conf.getInt("spark.executor.port", 0)
assert(boundPort != 0)
- // Start the CoarseGrainedExecutorBackend actor.
+ // Start the CoarseGrainedExecutorBackend endpoint.
val sparkHostPort = hostname + ":" + boundPort
- env.actorSystem.actorOf(
- Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
- name = "Executor")
+ env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
+ env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
workerUrl.foreach { url =>
env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
}
- env.actorSystem.awaitTermination()
+ env.rpcEnv.awaitTermination()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index bf3135e..14f99a4 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -27,8 +27,6 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.Props
-
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
@@ -88,9 +86,9 @@ private[spark] class Executor(
env.blockManager.initialize(conf.getAppId)
}
- // Create an actor for receiving RPCs from the driver
- private val executorActor = env.actorSystem.actorOf(
- Props(new ExecutorActor(executorId)), "ExecutorActor")
+ // Create an RpcEndpoint for receiving RPCs from the driver
+ private val executorEndpoint = env.rpcEnv.setupEndpoint(
+ ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId))
// Whether to load classes in user jars before those in Spark jars
private val userClassPathFirst: Boolean = {
@@ -139,7 +137,7 @@ private[spark] class Executor(
def stop(): Unit = {
env.metricsSystem.report()
- env.actorSystem.stop(executorActor)
+ env.rpcEnv.stop(executorEndpoint)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -391,11 +389,8 @@ private[spark] class Executor(
}
}
- private val timeout = AkkaUtils.lookupTimeout(conf)
- private val retryAttempts = AkkaUtils.numRetries(conf)
- private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
private val heartbeatReceiverRef =
- AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+ RpcUtils.makeDriverRef(HeartbeatReceiver.ENDPOINT_NAME, conf, env.rpcEnv)
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
@@ -426,8 +421,7 @@ private[spark] class Executor(
val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
- val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
- retryAttempts, retryIntervalMs, timeout)
+ val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
deleted file mode 100644
index 3e47d13..0000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import akka.actor.Actor
-import org.apache.spark.Logging
-
-import org.apache.spark.util.{Utils, ActorLogReceive}
-
-/**
- * Driver -> Executor message to trigger a thread dump.
- */
-private[spark] case object TriggerThreadDump
-
-/**
- * Actor that runs inside of executors to enable driver -> executor RPC.
- */
-private[spark]
-class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
-
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
- case TriggerThreadDump =>
- sender ! Utils.getThreadDump()
- }
-
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
new file mode 100644
index 0000000..cf362f8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorEndpoint.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
+import org.apache.spark.util.Utils
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * [[RpcEndpoint]] that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorEndpoint(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint {
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case TriggerThreadDump =>
+ context.reply(Utils.getThreadDump())
+ }
+
+}
+
+object ExecutorEndpoint {
+ val EXECUTOR_ENDPOINT_NAME = "ExecutorEndpoint"
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 7985941..d47e41a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -40,10 +40,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
/**
* Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
- * [[RpcEndpoint.self]].
- *
- * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this
- * on a non-existent endpoint.
+ * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist.
*/
private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef
@@ -59,20 +56,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
/**
- * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should
- * make sure thread-safely sending messages to [[RpcEndpoint]].
- *
- * Thread-safety means processing of one message happens before processing of the next message by
- * the same [[RpcEndpoint]]. In the other words, changes to internal fields of a [[RpcEndpoint]]
- * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be
- * volatile or equivalent.
- *
- * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]]
- * for different messages.
- */
- def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
-
- /**
* Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously.
*/
def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef]
@@ -181,7 +164,7 @@ private[spark] trait RpcEnvFactory {
* constructor onStart receive* onStop
*
* Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use
- * [[RpcEnv.setupThreadSafeEndpoint]]
+ * [[ThreadSafeRpcEndpoint]]
*
* If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
* invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it.
@@ -195,7 +178,7 @@ private[spark] trait RpcEndpoint {
/**
* The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is
- * called.
+ * called. And `self` will become `null` when `onStop` is called.
*
* Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not
* valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called.
@@ -279,6 +262,19 @@ private[spark] trait RpcEndpoint {
}
/**
+ * A trait that requires RpcEnv to send messages to it in a thread-safe manner.
+ *
+ * Thread-safety means processing of one message happens before processing of the next message by
+ * the same [[ThreadSafeRpcEndpoint]]. In other words, changes to internal fields of a
+ * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
+ * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
+ *
+ * However, there is no guarantee that the same thread will be executing the same
+ * [[ThreadSafeRpcEndpoint]] for different messages.
+ */
+trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
* A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe.
*/
private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
@@ -407,7 +403,8 @@ private[spark] object RpcAddress {
}
/**
- * A callback that [[RpcEndpoint]] can use it to send back a message or failure.
+ * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe
+ * and can be called in any thread.
*/
private[spark] trait RpcCallContext {
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 769d59b..9e06147 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -82,17 +82,9 @@ private[spark] class AkkaRpcEnv private[akka] (
/**
* Retrieve the [[RpcEndpointRef]] of `endpoint`.
*/
- override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
- val endpointRef = endpointToRef.get(endpoint)
- require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}")
- endpointRef
- }
+ override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint)
override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
- setupThreadSafeEndpoint(name, endpoint)
- }
-
- override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
@volatile var endpointRef: AkkaRpcEndpointRef = null
// Use lazy because the Actor needs to use `endpointRef`.
// So `actorRef` should be created after assigning `endpointRef`.
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7227fa9..917cce1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -23,14 +23,10 @@ import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
-import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.control.NonFatal
-import akka.pattern.ask
-import akka.util.Timeout
-
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
@@ -165,11 +161,8 @@ class DAGScheduler(
taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
- implicit val timeout = Timeout(600 seconds)
-
- Await.result(
- blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId),
- timeout.duration).asInstanceOf[Boolean]
+ blockManagerMaster.driverEndpoint.askWithReply[Boolean](
+ BlockManagerHeartbeat(blockManagerId), 600 seconds)
}
// Called by TaskScheduler when an executor fails.
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 9bf74f4..70364ce 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.{SerializableBuffer, Utils}
private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
@@ -41,6 +42,7 @@ private[spark] object CoarseGrainedClusterMessages {
// Executors to driver
case class RegisterExecutor(
executorId: String,
+ executorRef: RpcEndpointRef,
hostPort: String,
cores: Int,
logUrls: Map[String, String])
@@ -70,6 +72,8 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
+ case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage
+
// Exchanged between the driver and the AM in Yarn client mode
case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
@@ -77,7 +81,7 @@ private[spark] object CoarseGrainedClusterMessages {
// Messages exchanged between the driver and the cluster manager for executor allocation
// In Yarn mode, these are exchanged between the driver and the AM
- case object RegisterClusterManager extends CoarseGrainedClusterMessage
+ case class RegisterClusterManager(am: RpcEndpointRef) extends CoarseGrainedClusterMessage
// Request executors by specifying the new total number of executors desired
// This includes executors already pending or running
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 5d258d9..4c49da8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -17,20 +17,16 @@
package org.apache.spark.scheduler.cluster
+import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.concurrent.Await
-import scala.concurrent.duration._
-
-import akka.actor._
-import akka.pattern.ask
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import org.apache.spark.rpc._
import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
+import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils}
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)
extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -49,7 +45,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Total number of executors that are currently registered
var totalRegisteredExecutors = new AtomicInteger(0)
val conf = scheduler.sc.conf
- private val timeout = AkkaUtils.askTimeout(conf)
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
// Submit tasks only after (registered resources / total expected resources)
// is equal to at least this value, that is double between 0 and 1.
@@ -71,48 +66,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Executors we have requested the cluster manager to kill that have not died yet
private val executorsPendingToRemove = new HashSet[String]
- class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
+ class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
+ extends ThreadSafeRpcEndpoint with Logging {
override protected def log = CoarseGrainedSchedulerBackend.this.log
- private val addressToExecutorId = new HashMap[Address, String]
- override def preStart() {
- // Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ private val addressToExecutorId = new HashMap[RpcAddress, String]
+
+ private val reviveThread =
+ Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread"))
+ override def onStart() {
// Periodically revive offers to allow delay scheduling to work
val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000)
- import context.dispatcher
- context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
- }
-
- def receiveWithLogging: PartialFunction[Any, Unit] = {
- case RegisterExecutor(executorId, hostPort, cores, logUrls) =>
- Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
- if (executorDataMap.contains(executorId)) {
- sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
- } else {
- logInfo("Registered executor: " + sender + " with ID " + executorId)
- sender ! RegisteredExecutor
-
- addressToExecutorId(sender.path.address) = executorId
- totalCoreCount.addAndGet(cores)
- totalRegisteredExecutors.addAndGet(1)
- val (host, _) = Utils.parseHostPort(hostPort)
- val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls)
- // This must be synchronized because variables mutated
- // in this block are read when requesting executors
- CoarseGrainedSchedulerBackend.this.synchronized {
- executorDataMap.put(executorId, data)
- if (numPendingExecutors > 0) {
- numPendingExecutors -= 1
- logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
- }
- }
- listenerBus.post(
- SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
- makeOffers()
+ reviveThread.scheduleAtFixedRate(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(ReviveOffers))
}
+ }, 0, reviveInterval, TimeUnit.MILLISECONDS)
+ }
+ override def receive: PartialFunction[Any, Unit] = {
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
@@ -133,33 +106,58 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
case KillTask(taskId, executorId, interruptThread) =>
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
- executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread)
+ executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
case None =>
// Ignoring the task kill since the executor is not registered.
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
}
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) =>
+ Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
+ if (executorDataMap.contains(executorId)) {
+ context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
+ } else {
+ logInfo("Registered executor: " + executorRef + " with ID " + executorId)
+ context.reply(RegisteredExecutor)
+
+ addressToExecutorId(executorRef.address) = executorId
+ totalCoreCount.addAndGet(cores)
+ totalRegisteredExecutors.addAndGet(1)
+ val (host, _) = Utils.parseHostPort(hostPort)
+ val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap.put(executorId, data)
+ if (numPendingExecutors > 0) {
+ numPendingExecutors -= 1
+ logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
+ }
+ }
+ listenerBus.post(
+ SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
+ makeOffers()
+ }
case StopDriver =>
- sender ! true
- context.stop(self)
+ context.reply(true)
+ stop()
case StopExecutors =>
logInfo("Asking each executor to shut down")
for ((_, executorData) <- executorDataMap) {
- executorData.executorActor ! StopExecutor
+ executorData.executorEndpoint.send(StopExecutor)
}
- sender ! true
+ context.reply(true)
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
- sender ! true
-
- case DisassociatedEvent(_, address, _) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_,
- "remote Akka client disassociated"))
+ context.reply(true)
case RetrieveSparkProps =>
- sender ! sparkProperties
+ context.reply(sparkProperties)
}
// Make fake resource offers on all executors
@@ -169,6 +167,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}.toSeq))
}
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_,
+ "remote Rpc client disassociated"))
+ }
+
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
val executorData = executorDataMap(executorId)
@@ -199,7 +202,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
else {
val executorData = executorDataMap(task.executorId)
executorData.freeCores -= scheduler.CPUS_PER_TASK
- executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask))
+ executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
}
}
}
@@ -223,9 +226,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
+
+ override def onStop() {
+ reviveThread.shutdownNow()
+ }
}
- var driverActor: ActorRef = null
+ var driverEndpoint: RpcEndpointRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]]
override def start() {
@@ -236,16 +243,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}
}
// TODO (prashant) send conf instead of properties
- driverActor = actorSystem.actorOf(
- Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ driverEndpoint = rpcEnv.setupEndpoint(
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties))
}
def stopExecutors() {
try {
- if (driverActor != null) {
+ if (driverEndpoint != null) {
logInfo("Shutting down all executors")
- val future = driverActor.ask(StopExecutors)(timeout)
- Await.ready(future, timeout)
+ driverEndpoint.askWithReply[Boolean](StopExecutors)
}
} catch {
case e: Exception =>
@@ -256,22 +262,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
override def stop() {
stopExecutors()
try {
- if (driverActor != null) {
- val future = driverActor.ask(StopDriver)(timeout)
- Await.ready(future, timeout)
+ if (driverEndpoint != null) {
+ driverEndpoint.askWithReply[Boolean](StopDriver)
}
} catch {
case e: Exception =>
- throw new SparkException("Error stopping standalone scheduler's driver actor", e)
+ throw new SparkException("Error stopping standalone scheduler's driver endpoint", e)
}
}
override def reviveOffers() {
- driverActor ! ReviveOffers
+ driverEndpoint.send(ReviveOffers)
}
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
- driverActor ! KillTask(taskId, executorId, interruptThread)
+ driverEndpoint.send(KillTask(taskId, executorId, interruptThread))
}
override def defaultParallelism(): Int = {
@@ -281,11 +286,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Called by subclasses when notified of a lost worker
def removeExecutor(executorId: String, reason: String) {
try {
- val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
- Await.ready(future, timeout)
+ driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason))
} catch {
case e: Exception =>
- throw new SparkException("Error notifying standalone scheduler's driver actor", e)
+ throw new SparkException("Error notifying standalone scheduler's driver endpoint", e)
}
}
@@ -391,5 +395,5 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}
private[spark] object CoarseGrainedSchedulerBackend {
- val ACTOR_NAME = "CoarseGrainedScheduler"
+ val ENDPOINT_NAME = "CoarseGrainedScheduler"
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index 5e571ef..26e72c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -17,20 +17,20 @@
package org.apache.spark.scheduler.cluster
-import akka.actor.{Address, ActorRef}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress}
/**
* Grouping of data for an executor used by CoarseGrainedSchedulerBackend.
*
- * @param executorActor The ActorRef representing this executor
+ * @param executorEndpoint The RpcEndpointRef representing this executor
* @param executorAddress The network address of this executor
* @param executorHost The hostname that this executor is running on
* @param freeCores The current number of cores available for work on the executor
* @param totalCores The total number of cores available to the executor
*/
private[cluster] class ExecutorData(
- val executorActor: ActorRef,
- val executorAddress: Address,
+ val executorEndpoint: RpcEndpointRef,
+ val executorAddress: RpcAddress,
override val executorHost: String,
var freeCores: Int,
override val totalCores: Int,
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 06786a5..0324c9d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -19,16 +19,16 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.rpc.RpcAddress
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
-import org.apache.spark.util.AkkaUtils
private[spark] class SimrSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
driverFilePath: String)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with Logging {
val tmpPath = new Path(driverFilePath + "_tmp")
@@ -39,12 +39,9 @@ private[spark] class SimrSchedulerBackend(
override def start() {
super.start()
- val driverUrl = AkkaUtils.address(
- AkkaUtils.protocol(actorSystem),
- SparkEnv.driverActorSystemName,
- sc.conf.get("spark.driver.host"),
- sc.conf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
+ RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val conf = SparkHadoopUtil.get.newConfiguration(sc.conf)
val fs = FileSystem.get(conf)
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ffd4825..7eb3fdc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -19,17 +19,18 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.Semaphore
+import org.apache.spark.rpc.RpcAddress
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
masters: Array[String])
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with AppClientListener
with Logging {
@@ -48,12 +49,9 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = AkkaUtils.address(
- AkkaUtils.protocol(actorSystem),
- SparkEnv.driverActorSystemName,
- conf.get("spark.driver.host"),
- conf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName,
+ RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt),
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val args = Seq(
"--driver-url", driverUrl,
"--executor-id", "{{EXECUTOR_ID}}",
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 5a38ad9..f72566c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -19,10 +19,8 @@ package org.apache.spark.scheduler.cluster
import scala.concurrent.{Future, ExecutionContext}
-import akka.actor.{Actor, ActorRef, Props}
-import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.rpc._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.ui.JettyUtils
@@ -37,7 +35,7 @@ import scala.util.control.NonFatal
private[spark] abstract class YarnSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) {
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
minRegisteredRatio = 0.8
@@ -45,10 +43,8 @@ private[spark] abstract class YarnSchedulerBackend(
protected var totalExpectedExecutors = 0
- private val yarnSchedulerActor: ActorRef =
- actorSystem.actorOf(
- Props(new YarnSchedulerActor),
- name = YarnSchedulerBackend.ACTOR_NAME)
+ private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint(
+ YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv))
private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf)
@@ -57,16 +53,14 @@ private[spark] abstract class YarnSchedulerBackend(
* This includes executors already pending or running.
*/
override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- AkkaUtils.askWithReply[Boolean](
- RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout)
+ yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal))
}
/**
* Request that the ApplicationMaster kill the specified executors.
*/
override def doKillExecutors(executorIds: Seq[String]): Boolean = {
- AkkaUtils.askWithReply[Boolean](
- KillExecutors(executorIds), yarnSchedulerActor, askTimeout)
+ yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds))
}
override def sufficientResourcesRegistered(): Boolean = {
@@ -96,64 +90,71 @@ private[spark] abstract class YarnSchedulerBackend(
}
/**
- * An actor that communicates with the ApplicationMaster.
+ * An [[RpcEndpoint]] that communicates with the ApplicationMaster.
*/
- private class YarnSchedulerActor extends Actor {
- private var amActor: Option[ActorRef] = None
-
- implicit val askAmActorExecutor = ExecutionContext.fromExecutor(
- Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor"))
+ private class YarnSchedulerEndpoint(override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with Logging {
+ private var amEndpoint: Option[RpcEndpointRef] = None
- override def preStart(): Unit = {
- // Listen for disassociation events
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- }
+ private val askAmThreadPool =
+ Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool")
+ implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool)
override def receive: PartialFunction[Any, Unit] = {
- case RegisterClusterManager =>
- logInfo(s"ApplicationMaster registered as $sender")
- amActor = Some(sender)
+ case RegisterClusterManager(am) =>
+ logInfo(s"ApplicationMaster registered as $am")
+ amEndpoint = Some(am)
+
+ case AddWebUIFilter(filterName, filterParams, proxyBase) =>
+ addWebUIFilter(filterName, filterParams, proxyBase)
+
+ }
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case r: RequestExecutors =>
- amActor match {
- case Some(actor) =>
- val driverActor = sender
+ amEndpoint match {
+ case Some(am) =>
Future {
- driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ context.reply(am.askWithReply[Boolean](r))
} onFailure {
- case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e)
+ case NonFatal(e) =>
+ logError(s"Sending $r to AM was unsuccessful", e)
+ context.sendFailure(e)
}
case None =>
logWarning("Attempted to request executors before the AM has registered!")
- sender ! false
+ context.reply(false)
}
case k: KillExecutors =>
- amActor match {
- case Some(actor) =>
- val driverActor = sender
+ amEndpoint match {
+ case Some(am) =>
Future {
- driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ context.reply(am.askWithReply[Boolean](k))
} onFailure {
- case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e)
+ case NonFatal(e) =>
+ logError(s"Sending $k to AM was unsuccessful", e)
+ context.sendFailure(e)
}
case None =>
logWarning("Attempted to kill executors before the AM has registered!")
- sender ! false
+ context.reply(false)
}
- case AddWebUIFilter(filterName, filterParams, proxyBase) =>
- addWebUIFilter(filterName, filterParams, proxyBase)
- sender ! true
+ }
- case d: DisassociatedEvent =>
- if (amActor.isDefined && sender == amActor.get) {
- logWarning(s"ApplicationMaster has disassociated: $d")
- }
+ override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+ if (amEndpoint.exists(_.address == remoteAddress)) {
+ logWarning(s"ApplicationMaster has disassociated: $remoteAddress")
+ }
+ }
+
+ override def onStop(): Unit ={
+ askAmThreadPool.shutdownNow()
}
}
}
private[spark] object YarnSchedulerBackend {
- val ACTOR_NAME = "YarnScheduler"
+ val ENDPOINT_NAME = "YarnScheduler"
}
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index e13de0f..b037a49 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String)
- extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with MScheduler
with Logging {
@@ -148,7 +148,7 @@ private[spark] class CoarseMesosSchedulerBackend(
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
- CoarseGrainedSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
val uri = conf.get("spark.executor.uri", null)
if (uri == null) {
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index eb3f999..70a477a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -18,17 +18,14 @@
package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
+import java.util.concurrent.{Executors, TimeUnit}
-import scala.concurrent.duration._
-import scala.language.postfixOps
-
-import akka.actor.{Actor, ActorRef, Props}
-
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
-import org.apache.spark.util.ActorLogReceive
private case class ReviveOffers()
@@ -39,17 +36,19 @@ private case class KillTask(taskId: Long, interruptThread: Boolean)
private case class StopExecutor()
/**
- * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
- * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
+ * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the
+ * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
* and the TaskSchedulerImpl.
*/
-private[spark] class LocalActor(
+private[spark] class LocalEndpoint(
+ override val rpcEnv: RpcEnv,
scheduler: TaskSchedulerImpl,
executorBackend: LocalBackend,
private val totalCores: Int)
- extends Actor with ActorLogReceive with Logging {
+ extends ThreadSafeRpcEndpoint with Logging {
- import context.dispatcher // to use Akka's scheduler.scheduleOnce()
+ private val reviveThread = Executors.newSingleThreadScheduledExecutor(
+ Utils.namedThreadFactory("local-revive-thread"))
private var freeCores = totalCores
@@ -59,7 +58,7 @@ private[spark] class LocalActor(
private val executor = new Executor(
localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
- override def receiveWithLogging: PartialFunction[Any, Unit] = {
+ override def receive: PartialFunction[Any, Unit] = {
case ReviveOffers =>
reviveOffers()
@@ -87,9 +86,17 @@ private[spark] class LocalActor(
}
if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) {
// Try to reviveOffer after 1 second, because scheduler may wait for locality timeout
- context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers)
+ reviveThread.schedule(new Runnable {
+ override def run(): Unit = Utils.tryLogNonFatalError {
+ Option(self).foreach(_.send(ReviveOffers))
+ }
+ }, 1000, TimeUnit.MILLISECONDS)
}
}
+
+ override def onStop(): Unit = {
+ reviveThread.shutdownNow()
+ }
}
/**
@@ -101,31 +108,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores:
extends SchedulerBackend with ExecutorBackend {
private val appId = "local-" + System.currentTimeMillis
- var localActor: ActorRef = null
+ var localEndpoint: RpcEndpointRef = null
override def start() {
- localActor = SparkEnv.get.actorSystem.actorOf(
- Props(new LocalActor(scheduler, this, totalCores)),
- "LocalBackendActor")
+ localEndpoint = SparkEnv.get.rpcEnv.setupEndpoint(
+ "LocalBackendEndpoint", new LocalEndpoint(SparkEnv.get.rpcEnv, scheduler, this, totalCores))
}
override def stop() {
- localActor ! StopExecutor
+ localEndpoint.send(StopExecutor)
}
override def reviveOffers() {
- localActor ! ReviveOffers
+ localEndpoint.send(ReviveOffers)
}
override def defaultParallelism(): Int =
scheduler.conf.getInt("spark.default.parallelism", totalCores)
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) {
- localActor ! KillTask(taskId, interruptThread)
+ localEndpoint.send(KillTask(taskId, interruptThread))
}
override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
- localActor ! StatusUpdate(taskId, state, serializedData)
+ localEndpoint.send(StatusUpdate(taskId, state, serializedData))
}
override def applicationId(): String = appId
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index fc31296..1aa0ef1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Random
-import akka.actor.{ActorSystem, Props}
import sun.nio.ch.DirectBuffer
import org.apache.spark._
@@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.rpc.RpcEnv
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -64,7 +64,7 @@ private[spark] class BlockResult(
*/
private[spark] class BlockManager(
executorId: String,
- actorSystem: ActorSystem,
+ rpcEnv: RpcEnv,
val master: BlockManagerMaster,
defaultSerializer: Serializer,
maxMemory: Long,
@@ -136,9 +136,9 @@ private[spark] class BlockManager(
// Whether to compress shuffle output temporarily spilled to disk
private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
- private val slaveActor = actorSystem.actorOf(
- Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
- name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
+ private val slaveEndpoint = rpcEnv.setupEndpoint(
+ "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next,
+ new BlockManagerSlaveEndpoint(rpcEnv, this, mapOutputTracker))
// Pending re-registration action being executed asynchronously or null if none is pending.
// Accesses should synchronize on asyncReregisterLock.
@@ -167,7 +167,7 @@ private[spark] class BlockManager(
*/
def this(
execId: String,
- actorSystem: ActorSystem,
+ rpcEnv: RpcEnv,
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
@@ -176,7 +176,7 @@ private[spark] class BlockManager(
blockTransferService: BlockTransferService,
securityManager: SecurityManager,
numUsableCores: Int) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
+ this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf),
conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores)
}
@@ -186,7 +186,7 @@ private[spark] class BlockManager(
* where it is only learned after registration with the TaskScheduler).
*
* This method initializes the BlockTransferService and ShuffleClient, registers with the
- * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * BlockManagerMaster, starts the BlockManagerWorker endpoint, and registers with a local shuffle
* service if configured.
*/
def initialize(appId: String): Unit = {
@@ -202,7 +202,7 @@ private[spark] class BlockManager(
blockManagerId
}
- master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+ master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
// Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
@@ -265,7 +265,7 @@ private[spark] class BlockManager(
def reregister(): Unit = {
// TODO: We might need to rate limit re-registering.
logInfo("BlockManager re-registering with master")
- master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+ master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
reportAllBlocks()
}
@@ -1215,7 +1215,7 @@ private[spark] class BlockManager(
shuffleClient.close()
}
diskBlockManager.stop()
- actorSystem.stop(slaveActor)
+ rpcEnv.stop(slaveEndpoint)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
http://git-wip-us.apache.org/repos/asf/spark/blob/f15806a8/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 0619648..ceacf04 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -20,35 +20,31 @@ package org.apache.spark.storage
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
-import akka.actor._
-
+import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.AkkaUtils
private[spark]
class BlockManagerMaster(
- var driverActor: ActorRef,
+ var driverEndpoint: RpcEndpointRef,
conf: SparkConf,
isDriver: Boolean)
extends Logging {
- private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf)
- private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf)
-
- val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
val timeout = AkkaUtils.askTimeout(conf)
- /** Remove a dead executor from the driver actor. This is only called on the driver side. */
+ /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */
def removeExecutor(execId: String) {
tell(RemoveExecutor(execId))
logInfo("Removed " + execId + " successfully in removeExecutor")
}
/** Register the BlockManager's id with the driver. */
- def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ def registerBlockManager(
+ blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
logInfo("Trying to register BlockManager")
- tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
+ tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
logInfo("Registered BlockManager")
}
@@ -59,7 +55,7 @@ class BlockManagerMaster(
memSize: Long,
diskSize: Long,
tachyonSize: Long): Boolean = {
- val res = askDriverWithReply[Boolean](
+ val res = driverEndpoint.askWithReply[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize))
logDebug(s"Updated info of block $blockId")
res
@@ -67,12 +63,12 @@ class BlockManagerMaster(
/** Get locations of the blockId from the driver */
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
- askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
+ driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
- askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
+ driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
/**
@@ -85,11 +81,11 @@ class BlockManagerMaster(
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
- askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
+ driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
- def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
- askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId))
}
/**
@@ -97,12 +93,12 @@ class BlockManagerMaster(
* blocks that the driver knows about.
*/
def removeBlock(blockId: BlockId) {
- askDriverWithReply(RemoveBlock(blockId))
+ driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId))
}
/** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
- val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
+ val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
@@ -114,7 +110,7 @@ class BlockManagerMaster(
/** Remove all blocks belonging to the given shuffle. */
def removeShuffle(shuffleId: Int, blocking: Boolean) {
- val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+ val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
@@ -126,7 +122,7 @@ class BlockManagerMaster(
/** Remove all blocks belonging to the given broadcast. */
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
- val future = askDriverWithReply[Future[Seq[Int]]](
+ val future = driverEndpoint.askWithReply[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.onFailure {
case e: Exception =>
@@ -145,11 +141,11 @@ class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
def getStorageStatus: Array[StorageStatus] = {
- askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
+ driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus)
}
/**
@@ -165,11 +161,12 @@ class BlockManagerMaster(
askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
val msg = GetBlockStatus(blockId, askSlaves)
/*
- * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+ * To avoid potential deadlocks, the use of Futures is necessary, because the master endpoint
* should not block on waiting for a block manager, which can in turn be waiting for the
- * master actor for a response to a prior message.
+ * master endpoint for a response to a prior message.
*/
- val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+ val response = driverEndpoint.
+ askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
val (blockManagerIds, futures) = response.unzip
val result = Await.result(Future.sequence(futures), timeout)
if (result == null) {
@@ -193,33 +190,28 @@ class BlockManagerMaster(
filter: BlockId => Boolean,
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
- val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
+ val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg)
Await.result(future, timeout)
}
- /** Stop the driver actor, called only on the Spark driver node */
+ /** Stop the driver endpoint, called only on the Spark driver node */
def stop() {
- if (driverActor != null && isDriver) {
+ if (driverEndpoint != null && isDriver) {
tell(StopBlockManagerMaster)
- driverActor = null
+ driverEndpoint = null
logInfo("BlockManagerMaster stopped")
}
}
- /** Send a one-way message to the master actor, to which we expect it to reply with true. */
+ /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */
private def tell(message: Any) {
- if (!askDriverWithReply[Boolean](message)) {
- throw new SparkException("BlockManagerMasterActor returned false, expected true.")
+ if (!driverEndpoint.askWithReply[Boolean](message)) {
+ throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.")
}
}
- /**
- * Send a message to the driver actor and get its result within a default timeout, or
- * throw a SparkException if this fails.
- */
- private def askDriverWithReply[T](message: Any): T = {
- AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS,
- timeout)
- }
+}
+private[spark] object BlockManagerMaster {
+ val DRIVER_ENDPOINT_NAME = "BlockManagerMaster"
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org