You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2015/07/28 02:59:47 UTC

spark git commit: [SPARK-8882] [STREAMING] Add a new Receiver scheduling mechanism

Repository: spark
Updated Branches:
  refs/heads/master ce89ff477 -> daa1964b6


[SPARK-8882] [STREAMING] Add a new Receiver scheduling mechanism

The design doc: https://docs.google.com/document/d/1ZsoRvHjpISPrDmSjsGzuSu8UjwgbtmoCTzmhgTurHJw/edit?usp=sharing

Author: zsxwing <zs...@gmail.com>

Closes #7276 from zsxwing/receiver-scheduling and squashes the following commits:

137b257 [zsxwing] Add preferredNumExecutors to rescheduleReceiver
61a6c3f [zsxwing] Set state to ReceiverState.INACTIVE in deregisterReceiver
5e1fa48 [zsxwing] Fix the code style
7451498 [zsxwing] Move DummyReceiver back to ReceiverTrackerSuite
715ef9c [zsxwing] Rename: scheduledLocations -> scheduledExecutors; locations -> executors
05daf9c [zsxwing] Use receiverTrackingInfo.toReceiverInfo
1d6d7c8 [zsxwing] Merge branch 'master' into receiver-scheduling
8f93c8d [zsxwing] Use hostPort as the receiver location rather than host; fix comments and unit tests
59f8887 [zsxwing] Schedule all receivers at the same time when launching them
075e0a3 [zsxwing] Add receiver RDD name; use '!isTrackerStarted' instead
276a4ac [zsxwing] Remove "ReceiverLauncher" and move codes to "launchReceivers"
fab9a01 [zsxwing] Move methods back to the outer class
4e639c4 [zsxwing] Fix unintentional changes
f60d021 [zsxwing] Reorganize ReceiverTracker to use an event loop for lock free
105037e [zsxwing] Merge branch 'master' into receiver-scheduling
5fee132 [zsxwing] Update the scheduling algorithm to avoid repeatedly restarting Receiver
9e242c8 [zsxwing] Remove the ScheduleReceiver message because we can refuse it when receiving RegisterReceiver
a9acfbf [zsxwing] Merge branch 'squash-pr-6294' into receiver-scheduling
881edb9 [zsxwing] ReceiverScheduler -> ReceiverSchedulingPolicy
e530bcc [zsxwing] [SPARK-5681][Streaming] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time #6294
3b87e4a [zsxwing] Revert SparkContext.scala
a86850c [zsxwing] Remove submitAsyncJob and revert JobWaiter
f549595 [zsxwing] Add comments for the scheduling approach
9ecc08e [zsxwing] Fix comments and code style
28d1bee [zsxwing] Make 'host' protected; rescheduleReceiver -> getAllowedLocations
2c86a9e [zsxwing] Use tryFailure to support calling jobFailed multiple times
ca6fe35 [zsxwing] Add a test for Receiver.restart
27acd45 [zsxwing] Add unit tests for LoadBalanceReceiverSchedulerImplSuite
cc76142 [zsxwing] Add JobWaiter.toFuture to avoid blocking threads
d9a3e72 [zsxwing] Add a new Receiver scheduling mechanism


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/daa1964b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/daa1964b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/daa1964b

Branch: refs/heads/master
Commit: daa1964b6098f79100def78451bda181b5c92198
Parents: ce89ff4
Author: zsxwing <zs...@gmail.com>
Authored: Mon Jul 27 17:59:43 2015 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Mon Jul 27 17:59:43 2015 -0700

----------------------------------------------------------------------
 .../streaming/receiver/ReceiverSupervisor.scala |   4 +-
 .../receiver/ReceiverSupervisorImpl.scala       |   6 +-
 .../streaming/scheduler/ReceiverInfo.scala      |   1 -
 .../scheduler/ReceiverSchedulingPolicy.scala    | 171 +++++++
 .../streaming/scheduler/ReceiverTracker.scala   | 468 ++++++++++++-------
 .../scheduler/ReceiverTrackingInfo.scala        |  55 +++
 .../ReceiverSchedulingPolicySuite.scala         | 130 ++++++
 .../scheduler/ReceiverTrackerSuite.scala        |  66 +--
 .../ui/StreamingJobProgressListenerSuite.scala  |   6 +-
 9 files changed, 674 insertions(+), 233 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index a7c220f..e98017a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
 import scala.concurrent._
 import scala.util.control.NonFatal
 
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SparkEnv, Logging, SparkConf}
 import org.apache.spark.storage.StreamBlockId
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{Utils, ThreadUtils}
 
 /**
  * Abstract class that is responsible for supervising a Receiver in the worker.

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 2f6841e..0d802f8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -30,7 +30,7 @@ import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.Time
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.util.{RpcUtils, Utils}
+import org.apache.spark.util.RpcUtils
 import org.apache.spark.{Logging, SparkEnv, SparkException}
 
 /**
@@ -46,6 +46,8 @@ private[streaming] class ReceiverSupervisorImpl(
     checkpointDirOption: Option[String]
   ) extends ReceiverSupervisor(receiver, env.conf) with Logging {
 
+  private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort
+
   private val receivedBlockHandler: ReceivedBlockHandler = {
     if (WriteAheadLogUtils.enableReceiverLog(env.conf)) {
       if (checkpointDirOption.isEmpty) {
@@ -170,7 +172,7 @@ private[streaming] class ReceiverSupervisorImpl(
 
   override protected def onReceiverStart(): Boolean = {
     val msg = RegisterReceiver(
-      streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
+      streamId, receiver.getClass.getSimpleName, hostPort, endpoint)
     trackerEndpoint.askWithRetry[Boolean](msg)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
index de85f24..59df892 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala
@@ -28,7 +28,6 @@ import org.apache.spark.rpc.RpcEndpointRef
 case class ReceiverInfo(
     streamId: Int,
     name: String,
-    private[streaming] val endpoint: RpcEndpointRef,
     active: Boolean,
     location: String,
     lastErrorMessage: String = "",

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
new file mode 100644
index 0000000..ef5b687
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.Map
+import scala.collection.mutable
+
+import org.apache.spark.streaming.receiver.Receiver
+
+private[streaming] class ReceiverSchedulingPolicy {
+
+  /**
+   * Try our best to distribute receivers evenly among the executors. However, if the
+   * `preferredLocation`s of receivers are not even, we may not be able to schedule them evenly
+   * because we have to respect them.
+   *
+   * Here is the approach to schedule executors:
+   * <ol>
+   *   <li>First, schedule all the receivers with preferred locations (hosts), evenly among the
+   *       executors running on those hosts.</li>
+   *   <li>Then, schedule all other receivers evenly among all the executors such that overall
+   *       distribution over all the receivers is even.</li>
+   * </ol>
+   *
+   * This method is called when we start to launch receivers for the first time.
+   */
+  def scheduleReceivers(
+      receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = {
+    if (receivers.isEmpty) {
+      return Map.empty
+    }
+
+    if (executors.isEmpty) {
+      return receivers.map(_.streamId -> Seq.empty).toMap
+    }
+
+    val hostToExecutors = executors.groupBy(_.split(":")(0))
+    val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String])
+    val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+    // Set the initial value to 0
+    executors.foreach(e => numReceiversOnExecutor(e) = 0)
+
+    // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation",
+    // we need to make sure the "preferredLocation" is in the candidate scheduled executor list.
+    for (i <- 0 until receivers.length) {
+      // Note: preferredLocation is host but executors are host:port
+      receivers(i).preferredLocation.foreach { host =>
+        hostToExecutors.get(host) match {
+          case Some(executorsOnHost) =>
+            // preferredLocation is a known host. Select the executor that has the fewest receivers
+            // on this host
+            val leastScheduledExecutor =
+              executorsOnHost.minBy(executor => numReceiversOnExecutor(executor))
+            scheduledExecutors(i) += leastScheduledExecutor
+            numReceiversOnExecutor(leastScheduledExecutor) =
+              numReceiversOnExecutor(leastScheduledExecutor) + 1
+          case None =>
+            // preferredLocation is an unknown host.
+            // Note: There are two cases:
+            // 1. This executor is not up. But it may be up later.
+            // 2. This executor is dead, or it's not a host in the cluster.
+            // Currently, simply add host to the scheduled executors.
+            scheduledExecutors(i) += host
+        }
+      }
+    }
+
+    // For those receivers that don't have preferredLocation, make sure we assign at least one
+    // executor to them.
+    for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) {
+      // Select the executor that has the least receivers
+      val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2)
+      scheduledExecutorsForOneReceiver += leastScheduledExecutor
+      numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1
+    }
+
+    // Assign idle executors to receivers that have fewer executors
+    val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1)
+    for (executor <- idleExecutors) {
+      // Assign an idle executor to the receiver that has least candidate executors.
+      val leastScheduledExecutors = scheduledExecutors.minBy(_.size)
+      leastScheduledExecutors += executor
+    }
+
+    receivers.map(_.streamId).zip(scheduledExecutors).toMap
+  }
+
+  /**
+   * Return a list of candidate executors to run the receiver. If the list is empty, the caller can
+   * run this receiver in arbitrary executor. The caller can use `preferredNumExecutors` to require
+   * returning `preferredNumExecutors` executors if possible.
+   *
+   * This method tries to balance executors' load. Here is the approach to schedule executors
+   * for a receiver.
+   * <ol>
+   *   <li>
+   *     If preferredLocation is set, preferredLocation should be one of the candidate executors.
+   *   </li>
+   *   <li>
+   *     Every executor will be assigned a weight according to the receivers running or
+   *     scheduled on it.
+   *     <ul>
+   *       <li>
+   *         If a receiver is running on an executor, it contributes 1.0 to the executor's weight.
+   *       </li>
+   *       <li>
+   *         If a receiver is scheduled to an executor but has not yet run, it contributes
+   *         `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.</li>
+   *     </ul>
+   *     Finally, if there are at least `preferredNumExecutors` idle executors (weight = 0),
+   *     return all idle executors. Otherwise, return the `preferredNumExecutors` best options
+   *     according to the weights.
+   *   </li>
+   * </ol>
+   *
+   * This method is called when a receiver is registering with ReceiverTracker or is restarting.
+   */
+  def rescheduleReceiver(
+      receiverId: Int,
+      preferredLocation: Option[String],
+      receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo],
+      executors: Seq[String],
+      preferredNumExecutors: Int = 3): Seq[String] = {
+    if (executors.isEmpty) {
+      return Seq.empty
+    }
+
+    // Always try to schedule to the preferred locations
+    val scheduledExecutors = mutable.Set[String]()
+    scheduledExecutors ++= preferredLocation
+
+    val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo =>
+      receiverTrackingInfo.state match {
+        case ReceiverState.INACTIVE => Nil
+        case ReceiverState.SCHEDULED =>
+          val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get
+          // The probability that a scheduled receiver will run in an executor is
+          // 1.0 / scheduledExecutors.size
+          scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size))
+        case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0)
+      }
+    }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor
+
+    val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq
+    if (idleExecutors.size >= preferredNumExecutors) {
+      // If there are at least `preferredNumExecutors` idle executors, return all of them
+      scheduledExecutors ++= idleExecutors
+    } else {
+      // Otherwise, return the `preferredNumExecutors` best options according to the weights
+      scheduledExecutors ++= idleExecutors
+      val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1)
+      scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors)
+    }
+    scheduledExecutors.toSeq
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 9cc6ffc..6270137 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -17,17 +17,27 @@
 
 package org.apache.spark.streaming.scheduler
 
-import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap}
+import java.util.concurrent.{TimeUnit, CountDownLatch}
+
+import scala.collection.mutable.HashMap
+import scala.concurrent.ExecutionContext
 import scala.language.existentials
-import scala.math.max
+import scala.util.{Failure, Success}
 
 import org.apache.spark.streaming.util.WriteAheadLogUtils
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc._
 import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl,
-  StopReceiver, UpdateRateLimit}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.streaming.receiver._
+import org.apache.spark.util.{ThreadUtils, SerializableConfiguration}
+
+
+/** Enumeration to identify current state of a Receiver */
+private[streaming] object ReceiverState extends Enumeration {
+  type ReceiverState = Value
+  val INACTIVE, SCHEDULED, ACTIVE = Value
+}
 
 /**
  * Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -37,7 +47,7 @@ private[streaming] sealed trait ReceiverTrackerMessage
 private[streaming] case class RegisterReceiver(
     streamId: Int,
     typ: String,
-    host: String,
+    hostPort: String,
     receiverEndpoint: RpcEndpointRef
   ) extends ReceiverTrackerMessage
 private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo)
@@ -46,7 +56,38 @@ private[streaming] case class ReportError(streamId: Int, message: String, error:
 private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String)
   extends ReceiverTrackerMessage
 
-private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage
+/**
+ * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally.
+ */
+private[streaming] sealed trait ReceiverTrackerLocalMessage
+
+/**
+ * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver.
+ */
+private[streaming] case class RestartReceiver(receiver: Receiver[_])
+  extends ReceiverTrackerLocalMessage
+
+/**
+ * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers
+ * for the first time.
+ */
+private[streaming] case class StartAllReceivers(receiver: Seq[Receiver[_]])
+  extends ReceiverTrackerLocalMessage
+
+/**
+ * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered
+ * receivers.
+ */
+private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage
+
+/**
+ * A message used by ReceiverTracker to ask for the ids of all receivers still stored in
+ * ReceiverTrackerEndpoint.
+ */
+private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage
+
+private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long)
+  extends ReceiverTrackerLocalMessage
 
 /**
  * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of
@@ -60,8 +101,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
   private val receiverInputStreams = ssc.graph.getReceiverInputStreams()
   private val receiverInputStreamIds = receiverInputStreams.map { _.id }
-  private val receiverExecutor = new ReceiverLauncher()
-  private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
   private val receivedBlockTracker = new ReceivedBlockTracker(
     ssc.sparkContext.conf,
     ssc.sparkContext.hadoopConfiguration,
@@ -86,6 +125,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
   // This not being null means the tracker has been started and not stopped
   private var endpoint: RpcEndpointRef = null
 
+  private val schedulingPolicy = new ReceiverSchedulingPolicy()
+
+  // Track the number of active receiver jobs. When a receiver job finally exits, countDown will
+  // be called.
+  private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size)
+
+  /**
+   * Track all receivers' information. The key is the receiver id, the value is the receiver info.
+   * It's only accessed in ReceiverTrackerEndpoint.
+   */
+  private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo]
+
+  /**
+   * Store all preferred locations for all receivers. We need this information to schedule
+   * receivers. It's only accessed in ReceiverTrackerEndpoint.
+   */
+  private val receiverPreferredLocations = new HashMap[Int, Option[String]]
+
   /** Start the endpoint and receiver execution thread. */
   def start(): Unit = synchronized {
     if (isTrackerStarted) {
@@ -95,7 +152,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     if (!receiverInputStreams.isEmpty) {
       endpoint = ssc.env.rpcEnv.setupEndpoint(
         "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv))
-      if (!skipReceiverLaunch) receiverExecutor.start()
+      if (!skipReceiverLaunch) launchReceivers()
       logInfo("ReceiverTracker started")
       trackerState = Started
     }
@@ -112,20 +169,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
         // Wait for the Spark job that runs the receivers to be over
         // That is, for the receivers to quit gracefully.
-        receiverExecutor.awaitTermination(10000)
+        receiverJobExitLatch.await(10, TimeUnit.SECONDS)
 
         if (graceful) {
-          val pollTime = 100
           logInfo("Waiting for receiver job to terminate gracefully")
-          while (receiverInfo.nonEmpty || receiverExecutor.running) {
-            Thread.sleep(pollTime)
-          }
+          receiverJobExitLatch.await()
           logInfo("Waited for receiver job to terminate gracefully")
         }
 
         // Check if all the receivers have been deregistered or not
-        if (receiverInfo.nonEmpty) {
-          logWarning("Not all of the receivers have deregistered, " + receiverInfo)
+        val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds)
+        if (receivers.nonEmpty) {
+          logWarning("Not all of the receivers have deregistered, " + receivers)
         } else {
           logInfo("All of the receivers have deregistered successfully")
         }
@@ -154,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
   /** Get the blocks allocated to the given batch and stream. */
   def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
-    synchronized {
-      receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId)
-    }
+    receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId)
   }
 
   /**
@@ -170,8 +223,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     // Signal the receivers to delete old block data
     if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) {
       logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
-      receiverInfo.values.flatMap { info => Option(info.endpoint) }
-        .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) }
+      endpoint.send(CleanupOldBlocks(cleanupThreshTime))
     }
   }
 
@@ -179,7 +231,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
   private def registerReceiver(
       streamId: Int,
       typ: String,
-      host: String,
+      hostPort: String,
       receiverEndpoint: RpcEndpointRef,
       senderAddress: RpcAddress
     ): Boolean = {
@@ -189,13 +241,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
     if (isTrackerStopping || isTrackerStopped) {
       false
+    } else if (!scheduleReceiver(streamId).contains(hostPort)) {
+      // Refuse it since it's scheduled to a wrong executor
+      false
     } else {
-      // "stopReceivers" won't happen at the same time because both "registerReceiver" and are
-      // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If
-      // "stopReceivers" is called later, it should be able to see this receiver.
-      receiverInfo(streamId) = ReceiverInfo(
-        streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
-      listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
+      val name = s"${typ}-${streamId}"
+      val receiverTrackingInfo = ReceiverTrackingInfo(
+        streamId,
+        ReceiverState.ACTIVE,
+        scheduledExecutors = None,
+        runningExecutor = Some(hostPort),
+        name = Some(name),
+        endpoint = Some(receiverEndpoint))
+      receiverTrackingInfos.put(streamId, receiverTrackingInfo)
+      listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo))
       logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
       true
     }
@@ -203,21 +262,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
   /** Deregister a receiver */
   private def deregisterReceiver(streamId: Int, message: String, error: String) {
-    val newReceiverInfo = receiverInfo.get(streamId) match {
+    val lastErrorTime =
+      if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis()
+    val errorInfo = ReceiverErrorInfo(
+      lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime)
+    val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match {
       case Some(oldInfo) =>
-        val lastErrorTime =
-          if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis()
-        oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message,
-          lastError = error, lastErrorTime = lastErrorTime)
+        oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo))
       case None =>
         logWarning("No prior receiver info")
-        val lastErrorTime =
-          if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis()
-        ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message,
-          lastError = error, lastErrorTime = lastErrorTime)
+        ReceiverTrackingInfo(
+          streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo))
     }
-    receiverInfo -= streamId
-    listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo))
+    receiverTrackingInfos -= streamId
+    listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo))
     val messageWithError = if (error != null && !error.isEmpty) {
       s"$message - $error"
     } else {
@@ -228,9 +286,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
   /** Update a receiver's maximum ingestion rate */
   def sendRateUpdate(streamUID: Int, newRate: Long): Unit = {
-    for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) {
-      eP.send(UpdateRateLimit(newRate))
-    }
+    endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
   }
 
   /** Add new blocks for the given stream */
@@ -240,16 +296,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
   /** Report error sent by a receiver */
   private def reportError(streamId: Int, message: String, error: String) {
-    val newReceiverInfo = receiverInfo.get(streamId) match {
+    val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match {
       case Some(oldInfo) =>
-        oldInfo.copy(lastErrorMessage = message, lastError = error)
+        val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error,
+          lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L))
+        oldInfo.copy(errorInfo = Some(errorInfo))
       case None =>
         logWarning("No prior receiver info")
-        ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message,
-          lastError = error, lastErrorTime = ssc.scheduler.clock.getTimeMillis())
+        val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error,
+          lastErrorTime = ssc.scheduler.clock.getTimeMillis())
+        ReceiverTrackingInfo(
+          streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo))
     }
-    receiverInfo(streamId) = newReceiverInfo
-    listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
+
+    receiverTrackingInfos(streamId) = newReceiverTrackingInfo
+    listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo))
     val messageWithError = if (error != null && !error.isEmpty) {
       s"$message - $error"
     } else {
@@ -258,171 +319,242 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     logWarning(s"Error reported by receiver for stream $streamId: $messageWithError")
   }
 
+  private def scheduleReceiver(receiverId: Int): Seq[String] = {
+    val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None)
+    val scheduledExecutors = schedulingPolicy.rescheduleReceiver(
+      receiverId, preferredLocation, receiverTrackingInfos, getExecutors)
+    updateReceiverScheduledExecutors(receiverId, scheduledExecutors)
+    scheduledExecutors
+  }
+
+  private def updateReceiverScheduledExecutors(
+      receiverId: Int, scheduledExecutors: Seq[String]): Unit = {
+    val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match {
+      case Some(oldInfo) =>
+        oldInfo.copy(state = ReceiverState.SCHEDULED,
+          scheduledExecutors = Some(scheduledExecutors))
+      case None =>
+        ReceiverTrackingInfo(
+          receiverId,
+          ReceiverState.SCHEDULED,
+          Some(scheduledExecutors),
+          runningExecutor = None)
+    }
+    receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo)
+  }
+
   /** Check if any blocks are left to be processed */
   def hasUnallocatedBlocks: Boolean = {
     receivedBlockTracker.hasUnallocatedReceivedBlocks
   }
 
+  /**
+   * Get the list of executors excluding driver
+   */
+  private def getExecutors: Seq[String] = {
+    if (ssc.sc.isLocal) {
+      Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort)
+    } else {
+      ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) =>
+        blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location
+      }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq
+    }
+  }
+
+  /**
+   * Run the dummy Spark job to ensure that all slaves have registered. This prevents all the
+   * receivers from being scheduled on the same node.
+   *
+   * TODO Should poll the executor number and wait for executors according to
+   * "spark.scheduler.minRegisteredResourcesRatio" and
+   * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job.
+   */
+  private def runDummySparkJob(): Unit = {
+    if (!ssc.sparkContext.isLocal) {
+      ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
+    }
+    assert(getExecutors.nonEmpty)
+  }
+
+  /**
+   * Get the receivers from the ReceiverInputDStreams, distributes them to the
+   * worker nodes as a parallel collection, and runs them.
+   */
+  private def launchReceivers(): Unit = {
+    val receivers = receiverInputStreams.map(nis => {
+      val rcvr = nis.getReceiver()
+      rcvr.setReceiverId(nis.id)
+      rcvr
+    })
+
+    runDummySparkJob()
+
+    logInfo("Starting " + receivers.length + " receivers")
+    endpoint.send(StartAllReceivers(receivers))
+  }
+
+  /** Check if tracker has been marked for starting */
+  private def isTrackerStarted: Boolean = trackerState == Started
+
+  /** Check if tracker has been marked for stopping */
+  private def isTrackerStopping: Boolean = trackerState == Stopping
+
+  /** Check if tracker has been marked as stopped */
+  private def isTrackerStopped: Boolean = trackerState == Stopped
+
   /** RpcEndpoint to receive messages from the receivers. */
   private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
 
+    // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged
+    private val submitJobThreadPool = ExecutionContext.fromExecutorService(
+      ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool"))
+
     override def receive: PartialFunction[Any, Unit] = {
+      // Local messages
+      case StartAllReceivers(receivers) =>
+        val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors)
+        for (receiver <- receivers) {
+          val executors = scheduledExecutors(receiver.streamId)
+          updateReceiverScheduledExecutors(receiver.streamId, executors)
+          receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation
+          startReceiver(receiver, executors)
+        }
+      case RestartReceiver(receiver) =>
+        val scheduledExecutors = schedulingPolicy.rescheduleReceiver(
+          receiver.streamId,
+          receiver.preferredLocation,
+          receiverTrackingInfos,
+          getExecutors)
+        updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors)
+        startReceiver(receiver, scheduledExecutors)
+      case c: CleanupOldBlocks =>
+        receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c))
+      case UpdateReceiverRateLimit(streamUID, newRate) =>
+        for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) {
+          eP.send(UpdateRateLimit(newRate))
+        }
+      // Remote messages
       case ReportError(streamId, message, error) =>
         reportError(streamId, message, error)
     }
 
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-      case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
+      // Remote messages
+      case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) =>
         val successful =
-          registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
+          registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address)
         context.reply(successful)
       case AddBlock(receivedBlockInfo) =>
         context.reply(addBlock(receivedBlockInfo))
       case DeregisterReceiver(streamId, message, error) =>
         deregisterReceiver(streamId, message, error)
         context.reply(true)
+      // Local messages
+      case AllReceiverIds =>
+        context.reply(receiverTrackingInfos.keys.toSeq)
       case StopAllReceivers =>
         assert(isTrackerStopping || isTrackerStopped)
         stopReceivers()
         context.reply(true)
     }
 
-    /** Send stop signal to the receivers. */
-    private def stopReceivers() {
-      // Signal the receivers to stop
-      receiverInfo.values.flatMap { info => Option(info.endpoint)}
-        .foreach { _.send(StopReceiver) }
-      logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
-    }
-  }
-
-  /** This thread class runs all the receivers on the cluster.  */
-  class ReceiverLauncher {
-    @transient val env = ssc.env
-    @volatile @transient var running = false
-    @transient val thread = new Thread() {
-      override def run() {
-        try {
-          SparkEnv.set(env)
-          startReceivers()
-        } catch {
-          case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
-        }
-      }
-    }
-
-    def start() {
-      thread.start()
-    }
-
     /**
-     * Get the list of executors excluding driver
-     */
-    private def getExecutors(ssc: StreamingContext): List[String] = {
-      val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList
-      val driver = ssc.sparkContext.getConf.get("spark.driver.host")
-      executors.diff(List(driver))
-    }
-
-    /** Set host location(s) for each receiver so as to distribute them over
-     * executors in a round-robin fashion taking into account preferredLocation if set
+     * Start a receiver along with its scheduled executors
      */
-    private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]],
-      executors: List[String]): Array[ArrayBuffer[String]] = {
-      val locations = new Array[ArrayBuffer[String]](receivers.length)
-      var i = 0
-      for (i <- 0 until receivers.length) {
-        locations(i) = new ArrayBuffer[String]()
-        if (receivers(i).preferredLocation.isDefined) {
-          locations(i) += receivers(i).preferredLocation.get
-        }
+    private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = {
+      val receiverId = receiver.streamId
+      if (!isTrackerStarted) {
+        onReceiverJobFinish(receiverId)
+        return
       }
-      var count = 0
-      for (i <- 0 until max(receivers.length, executors.length)) {
-        if (!receivers(i % receivers.length).preferredLocation.isDefined) {
-          locations(i % receivers.length) += executors(count)
-          count += 1
-          if (count == executors.length) {
-            count = 0
-          }
-        }
-      }
-      locations
-    }
-
-    /**
-     * Get the receivers from the ReceiverInputDStreams, distributes them to the
-     * worker nodes as a parallel collection, and runs them.
-     */
-    private def startReceivers() {
-      val receivers = receiverInputStreams.map(nis => {
-        val rcvr = nis.getReceiver()
-        rcvr.setReceiverId(nis.id)
-        rcvr
-      })
 
       val checkpointDirOption = Option(ssc.checkpointDir)
       val serializableHadoopConf =
         new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration)
 
       // Function to start the receiver on the worker node
-      val startReceiver = (iterator: Iterator[Receiver[_]]) => {
-        if (!iterator.hasNext) {
-          throw new SparkException(
-            "Could not start receiver as object not found.")
-        }
-        val receiver = iterator.next()
-        val supervisor = new ReceiverSupervisorImpl(
-          receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption)
-        supervisor.start()
-        supervisor.awaitTermination()
-      }
-
-      // Run the dummy Spark job to ensure that all slaves have registered.
-      // This avoids all the receivers to be scheduled on the same node.
-      if (!ssc.sparkContext.isLocal) {
-        ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
-      }
+      val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf)
 
-      // Get the list of executors and schedule receivers
-      val executors = getExecutors(ssc)
-      val tempRDD =
-        if (!executors.isEmpty) {
-          val locations = scheduleReceivers(receivers, executors)
-          val roundRobinReceivers = (0 until receivers.length).map(i =>
-            (receivers(i), locations(i)))
-          ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers)
+      // Create the RDD using the scheduledExecutors to run the receiver in a Spark job
+      val receiverRDD: RDD[Receiver[_]] =
+        if (scheduledExecutors.isEmpty) {
+          ssc.sc.makeRDD(Seq(receiver), 1)
         } else {
-          ssc.sc.makeRDD(receivers, receivers.size)
+          ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors))
         }
+      receiverRDD.setName(s"Receiver $receiverId")
+      val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit](
+        receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ())
+      // We will keep restarting the receiver job until ReceiverTracker is stopped
+      future.onComplete {
+        case Success(_) =>
+          if (!isTrackerStarted) {
+            onReceiverJobFinish(receiverId)
+          } else {
+            logInfo(s"Restarting Receiver $receiverId")
+            self.send(RestartReceiver(receiver))
+          }
+        case Failure(e) =>
+          if (!isTrackerStarted) {
+            onReceiverJobFinish(receiverId)
+          } else {
+            logError("Receiver has been stopped. Try to restart it.", e)
+            logInfo(s"Restarting Receiver $receiverId")
+            self.send(RestartReceiver(receiver))
+          }
+      }(submitJobThreadPool)
+      logInfo(s"Receiver ${receiver.streamId} started")
+    }
 
-      // Distribute the receivers and start them
-      logInfo("Starting " + receivers.length + " receivers")
-      running = true
-      try {
-        ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
-        logInfo("All of the receivers have been terminated")
-      } finally {
-        running = false
-      }
+    override def onStop(): Unit = {
+      submitJobThreadPool.shutdownNow()
     }
 
     /**
-     * Wait until the Spark job that runs the receivers is terminated, or return when
-     * `milliseconds` elapses
+     * Called when a receiver is terminated. It means we won't restart its Spark job.
      */
-    def awaitTermination(milliseconds: Long): Unit = {
-      thread.join(milliseconds)
+    private def onReceiverJobFinish(receiverId: Int): Unit = {
+      receiverJobExitLatch.countDown()
+      receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo =>
+        if (receiverTrackingInfo.state == ReceiverState.ACTIVE) {
+          logWarning(s"Receiver $receiverId exited but didn't deregister")
+        }
+      }
     }
-  }
 
-  /** Check if tracker has been marked for starting */
-  private def isTrackerStarted(): Boolean = trackerState == Started
+    /** Send stop signal to the receivers. */
+    private def stopReceivers() {
+      receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) }
+      logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers")
+    }
+  }
 
-  /** Check if tracker has been marked for stopping */
-  private def isTrackerStopping(): Boolean = trackerState == Stopping
+}
 
-  /** Check if tracker has been marked for stopped */
-  private def isTrackerStopped(): Boolean = trackerState == Stopped
+/**
+ * Function to start the receiver on the worker node. Use a class instead of closure to avoid
+ * the serialization issue.
+ */
+private class StartReceiverFunc(
+    checkpointDirOption: Option[String],
+    serializableHadoopConf: SerializableConfiguration)
+  extends (Iterator[Receiver[_]] => Unit) with Serializable {
+
+  override def apply(iterator: Iterator[Receiver[_]]): Unit = {
+    if (!iterator.hasNext) {
+      throw new SparkException(
+        "Could not start receiver as object not found.")
+    }
+    if (TaskContext.get().attemptNumber() == 0) {
+      val receiver = iterator.next()
+      assert(iterator.hasNext == false)
+      val supervisor = new ReceiverSupervisorImpl(
+        receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption)
+      supervisor.start()
+      supervisor.awaitTermination()
+    } else {
+      // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it.
+    }
+  }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala
new file mode 100644
index 0000000..043ff4d
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.streaming.scheduler.ReceiverState._
+
+private[streaming] case class ReceiverErrorInfo(
+    lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L)
+
+/**
+ * Class having information about a receiver.
+ *
+ * @param receiverId the unique receiver id
+ * @param state the current Receiver state
+ * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy
+ * @param runningExecutor the running executor if the receiver is active
+ * @param name the receiver name
+ * @param endpoint the receiver endpoint. It can be used to send messages to the receiver
+ * @param errorInfo the receiver error information if it fails
+ */
+private[streaming] case class ReceiverTrackingInfo(
+    receiverId: Int,
+    state: ReceiverState,
+    scheduledExecutors: Option[Seq[String]],
+    runningExecutor: Option[String],
+    name: Option[String] = None,
+    endpoint: Option[RpcEndpointRef] = None,
+    errorInfo: Option[ReceiverErrorInfo] = None) {
+
+  def toReceiverInfo: ReceiverInfo = ReceiverInfo(
+    receiverId,
+    name.getOrElse(""),
+    state == ReceiverState.ACTIVE,
+    location = runningExecutor.getOrElse(""),
+    lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""),
+    lastError = errorInfo.map(_.lastError).getOrElse(""),
+    lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L)
+  )
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
new file mode 100644
index 0000000..93f920f
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkFunSuite
+
+class ReceiverSchedulingPolicySuite extends SparkFunSuite {
+
+  val receiverSchedulingPolicy = new ReceiverSchedulingPolicy
+
+  test("rescheduleReceiver: empty executors") {
+    val scheduledExecutors =
+      receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty)
+    assert(scheduledExecutors === Seq.empty)
+  }
+
+  test("rescheduleReceiver: receiver preferredLocation") {
+    val receiverTrackingInfoMap = Map(
+      0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None))
+    val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver(
+      0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2"))
+    assert(scheduledExecutors.toSet === Set("host1", "host2"))
+  }
+
+  test("rescheduleReceiver: return all idle executors if more than 3 idle executors") {
+    val executors = Seq("host1", "host2", "host3", "host4", "host5")
+    // host1 is busy; host2, host3, host4 and host5 are idle
+    val receiverTrackingInfoMap = Map(
+      0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")))
+    val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver(
+      1, None, receiverTrackingInfoMap, executors)
+    assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5"))
+  }
+
+  test("rescheduleReceiver: return 3 best options if less than 3 idle executors") {
+    val executors = Seq("host1", "host2", "host3", "host4", "host5")
+    // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0
+    // host4 and host5 are idle
+    val receiverTrackingInfoMap = Map(
+      0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")),
+      1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None),
+      2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None))
+    val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver(
+      3, None, receiverTrackingInfoMap, executors)
+    assert(scheduledExecutors.toSet === Set("host2", "host4", "host5"))
+  }
+
+  test("scheduleReceivers: " +
+    "schedule receivers evenly when there are more receivers than executors") {
+    val receivers = (0 until 6).map(new DummyReceiver(_))
+    val executors = (10000 until 10003).map(port => s"localhost:${port}")
+    val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
+    val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+    // There should be 2 receivers running on each executor and each receiver has one executor
+    scheduledExecutors.foreach { case (receiverId, executors) =>
+      assert(executors.size == 1)
+      numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1
+    }
+    assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap)
+  }
+
+
+  test("scheduleReceivers: " +
+    "schedule receivers evenly when there are more executors than receivers") {
+    val receivers = (0 until 3).map(new DummyReceiver(_))
+    val executors = (10000 until 10006).map(port => s"localhost:${port}")
+    val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
+    val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+    // There should be 1 receiver running on each executor and each receiver has two executors
+    scheduledExecutors.foreach { case (receiverId, executors) =>
+      assert(executors.size == 2)
+      executors.foreach { l =>
+        numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1
+      }
+    }
+    assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap)
+  }
+
+  test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") {
+    val receivers = (0 until 3).map(new DummyReceiver(_)) ++
+      (3 until 6).map(new DummyReceiver(_, Some("localhost")))
+    val executors = (10000 until 10003).map(port => s"localhost:${port}") ++
+      (10003 until 10006).map(port => s"localhost2:${port}")
+    val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
+    val numReceiversOnExecutor = mutable.HashMap[String, Int]()
+    // There should be 1 receiver running on each executor and each receiver has 1 executor
+    scheduledExecutors.foreach { case (receiverId, executors) =>
+      assert(executors.size == 1)
+      executors.foreach { l =>
+        numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1
+      }
+    }
+    assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap)
+    // Make sure we schedule the receivers to their preferredLocations
+    val executorsForReceiversWithPreferredLocation =
+      scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2)
+    // We can simply check the executor set because we know each receiver has only 1 executor
+    assert(executorsForReceiversWithPreferredLocation.toSet ===
+      (10000 until 10003).map(port => s"localhost:${port}").toSet)
+  }
+
+  test("scheduleReceivers: return empty if no receiver") {
+    assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty)
+  }
+
+  test("scheduleReceivers: return empty scheduled executors if no executors") {
+    val receivers = (0 until 3).map(new DummyReceiver(_))
+    val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty)
+    scheduledExecutors.foreach { case (receiverId, executors) =>
+      assert(executors.isEmpty)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index aadb723..e2159bd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -18,66 +18,18 @@
 package org.apache.spark.streaming.scheduler
 
 import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
-import org.apache.spark.streaming._
+
 import org.apache.spark.SparkConf
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
 import org.apache.spark.streaming.receiver._
-import org.apache.spark.util.Utils
-import org.apache.spark.streaming.dstream.InputDStream
-import scala.reflect.ClassTag
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.storage.StorageLevel
 
 /** Testsuite for receiver scheduling */
 class ReceiverTrackerSuite extends TestSuiteBase {
   val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
   val ssc = new StreamingContext(sparkConf, Milliseconds(100))
-  val tracker = new ReceiverTracker(ssc)
-  val launcher = new tracker.ReceiverLauncher()
-  val executors: List[String] = List("0", "1", "2", "3")
-
-  test("receiver scheduling - all or none have preferred location") {
-
-    def parse(s: String): Array[Array[String]] = {
-      val outerSplit = s.split("\\|")
-      val loc = new Array[Array[String]](outerSplit.length)
-      var i = 0
-      for (i <- 0 until outerSplit.length) {
-        loc(i) = outerSplit(i).split("\\,")
-      }
-      loc
-    }
-
-    def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) {
-      val receivers =
-        if (preferredLocation) {
-          Array.tabulate(numReceivers)(i => new DummyReceiver(host =
-            Some(((i + 1) % executors.length).toString)))
-        } else {
-          Array.tabulate(numReceivers)(_ => new DummyReceiver)
-        }
-      val locations = launcher.scheduleReceivers(receivers, executors)
-      val expectedLocations = parse(allocation)
-      assert(locations.deep === expectedLocations.deep)
-    }
-
-    testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0")
-    testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2")
-    testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0")
-  }
-
-  test("receiver scheduling - some have preferred location") {
-    val numReceivers = 4;
-    val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")),
-      new DummyReceiver, new DummyReceiver, new DummyReceiver)
-    val locations = launcher.scheduleReceivers(receivers, executors)
-    assert(locations(0)(0) === "1")
-    assert(locations(1)(0) === "0")
-    assert(locations(2)(0) === "1")
-    assert(locations(0).length === 1)
-    assert(locations(3).length === 1)
-  }
 
   test("Receiver tracker - propagates rate limit") {
     object ReceiverStartedWaiter extends StreamingListener {
@@ -134,19 +86,19 @@ private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
  * @note It's necessary to be a top-level object, or else serialization would create another
  *       one on the executor side and we won't be able to read its rate limit.
  */
-private object SingletonDummyReceiver extends DummyReceiver
+private object SingletonDummyReceiver extends DummyReceiver(0)
 
 /**
  * Dummy receiver implementation
  */
-private class DummyReceiver(host: Option[String] = None)
+private class DummyReceiver(receiverId: Int, host: Option[String] = None)
   extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
 
-  def onStart() {
-  }
+  setReceiverId(receiverId)
 
-  def onStop() {
-  }
+  override def onStart(): Unit = {}
+
+  override def onStop(): Unit = {}
 
   override def preferredLocation: Option[String] = host
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/daa1964b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index 40dc1fb..0891309 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -119,20 +119,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
     listener.numTotalReceivedRecords should be (600)
 
     // onReceiverStarted
-    val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost")
+    val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost")
     listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted))
     listener.receiverInfo(0) should be (Some(receiverInfoStarted))
     listener.receiverInfo(1) should be (None)
 
     // onReceiverError
-    val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost")
+    val receiverInfoError = ReceiverInfo(1, "test", true, "localhost")
     listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError))
     listener.receiverInfo(0) should be (Some(receiverInfoStarted))
     listener.receiverInfo(1) should be (Some(receiverInfoError))
     listener.receiverInfo(2) should be (None)
 
     // onReceiverStopped
-    val receiverInfoStopped = ReceiverInfo(2, "test", null, true, "localhost")
+    val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost")
     listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped))
     listener.receiverInfo(0) should be (Some(receiverInfoStarted))
     listener.receiverInfo(1) should be (Some(receiverInfoError))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org