Posted to commits@spark.apache.org by pw...@apache.org on 2014/04/08 09:00:33 UTC

git commit: [SPARK-1331] Added graceful shutdown to Spark Streaming

Repository: spark
Updated Branches:
  refs/heads/master 11eabbe12 -> 83ac9a4bb


[SPARK-1331] Added graceful shutdown to Spark Streaming

The current version of StreamingContext.stop() directly kills all the data receivers (NetworkReceiver) without waiting for the data already received to be persisted and processed. This PR provides the fix. Now, when StreamingContext.stop() is called, the following sequence of steps happens (a usage sketch follows the list).
1. The driver sends a stop signal to all the active receivers.
2. Each receiver, when it gets the stop signal from the driver, first stops receiving more data, then waits for the thread that persists data blocks to the BlockManager to finish persisting all received data, and finally quits.
3. After all the receivers have stopped, the driver waits for the JobGenerator and JobScheduler to finish processing all the received data.
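
For reference, a minimal sketch of how an application would invoke the new graceful shutdown (the master URL, socket source and the trivial count below are illustrative only, not part of this change):

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val ssc = new StreamingContext("local[2]", "GracefulStopExample", Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.count().print()
    ssc.start()
    ssc.awaitTermination(10000)  // let a few batches run
    // New in this change: block until all received data is persisted and processed
    ssc.stop(stopSparkContext = true, stopGracefully = true)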

It also fixes the semantics of StreamingContext.start() and stop(): start() now throws an exception if the context has already been started or stopped, while stop() only logs a warning if it is called before start() or called twice.
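
These semantics mirror the added StreamingContextSuite tests (intercept is ScalaTest's; registering an input stream before start(), which the tests do, is elided here):

    ssc.stop()                                  // stop before start: warns, returns
    ssc.start()
    intercept[SparkException] { ssc.start() }   // start twice: throws
    ssc.stop()
    ssc.stop()                                  // stop twice: warns, returns
    intercept[SparkException] { ssc.start() }   // start after stop: throws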

Author: Tathagata Das <ta...@gmail.com>

Closes #247 from tdas/graceful-shutdown and squashes the following commits:

61c0016 [Tathagata Das] Updated MIMA binary check excludes.
ae1d39b [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into graceful-shutdown
6b59cfc [Tathagata Das] Minor changes based on Andrew's comment on PR.
d0b8d65 [Tathagata Das] Reduced time taken by graceful shutdown unit test.
f55bc67 [Tathagata Das] Fix scalastyle
c69b3a7 [Tathagata Das] Updates based on Patrick's comments.
c43b8ae [Tathagata Das] Added graceful shutdown to Spark Streaming.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/83ac9a4b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/83ac9a4b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/83ac9a4b

Branch: refs/heads/master
Commit: 83ac9a4bbf272028d0c4639cbd1e12022b9ae77a
Parents: 11eabbe
Author: Tathagata Das <ta...@gmail.com>
Authored: Tue Apr 8 00:00:17 2014 -0700
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Tue Apr 8 00:00:17 2014 -0700

----------------------------------------------------------------------
 project/MimaBuild.scala                         |  24 +--
 .../org/apache/spark/streaming/Checkpoint.scala |  14 +-
 .../spark/streaming/StreamingContext.scala      |  48 +++++-
 .../api/java/JavaStreamingContext.scala         |  12 +-
 .../streaming/dstream/NetworkInputDStream.scala | 151 ++++++++++++------
 .../streaming/dstream/SocketInputDStream.scala  |   1 -
 .../streaming/receivers/ActorReceiver.scala     |   2 +-
 .../streaming/scheduler/JobGenerator.scala      | 124 +++++++++++----
 .../streaming/scheduler/JobScheduler.scala      |  56 ++++---
 .../scheduler/NetworkInputTracker.scala         | 154 +++++++++++--------
 .../org/apache/spark/streaming/util/Clock.scala |   5 +-
 .../spark/streaming/util/RecurringTimer.scala   |  62 ++++++--
 .../spark/streaming/BasicOperationsSuite.scala  |   4 +-
 .../spark/streaming/StreamingContextSuite.scala | 108 +++++++++++--
 .../apache/spark/streaming/TestSuiteBase.scala  |   2 +-
 15 files changed, 552 insertions(+), 215 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/project/MimaBuild.scala
----------------------------------------------------------------------
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index e7c9c47..5ea4817 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -58,17 +58,19 @@ object MimaBuild {
       SparkBuild.SPARK_VERSION match {
         case v if v.startsWith("1.0") =>
           Seq(
-             excludePackage("org.apache.spark.api.java"),
-             excludePackage("org.apache.spark.streaming.api.java"),
-             excludePackage("org.apache.spark.mllib")
-           ) ++
-           excludeSparkClass("rdd.ClassTags") ++
-           excludeSparkClass("util.XORShiftRandom") ++
-           excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
-           excludeSparkClass("mllib.optimization.SquaredGradient") ++
-           excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
-           excludeSparkClass("mllib.regression.LassoWithSGD") ++
-           excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
+            excludePackage("org.apache.spark.api.java"),
+            excludePackage("org.apache.spark.streaming.api.java"),
+            excludePackage("org.apache.spark.mllib")
+          ) ++
+          excludeSparkClass("rdd.ClassTags") ++
+          excludeSparkClass("util.XORShiftRandom") ++
+          excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
+          excludeSparkClass("mllib.optimization.SquaredGradient") ++
+          excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
+          excludeSparkClass("mllib.regression.LassoWithSGD") ++
+          excludeSparkClass("mllib.regression.LinearRegressionWithSGD") ++
+          excludeSparkClass("streaming.dstream.NetworkReceiver") ++
+          excludeSparkClass("streaming.dstream.NetworkReceiver#NetworkReceiverActor")
         case _ => Seq()
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index baf80fe..93023e8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -194,19 +194,19 @@ class CheckpointWriter(
     }
   }
 
-  def stop() {
-    synchronized {
-      if (stopped) {
-        return
-      }
-      stopped = true
-    }
+  def stop(): Unit = synchronized {
+    if (stopped) return
+
     executor.shutdown()
     val startTime = System.currentTimeMillis()
     val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+    if (!terminated) {
+      executor.shutdownNow()
+    }
     val endTime = System.currentTimeMillis()
     logInfo("CheckpointWriter executor terminated ? " + terminated +
       ", waited for " + (endTime - startTime) + " ms.")
+    stopped = true
   }
 
   private def fs = synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index e198c69..a4e236c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -158,6 +158,15 @@ class StreamingContext private[streaming] (
 
   private[streaming] val waiter = new ContextWaiter
 
+  /** Enumeration to identify current state of the StreamingContext */
+  private[streaming] object StreamingContextState extends Enumeration {
+    type StreamingContextState = Value
+    val Initialized, Started, Stopped = Value
+  }
+
+  import StreamingContextState._
+  private[streaming] var state = Initialized
+
   /**
    * Return the associated Spark context
    */
@@ -405,9 +414,18 @@ class StreamingContext private[streaming] (
   /**
    * Start the execution of the streams.
    */
-  def start() = synchronized {
+  def start(): Unit = synchronized {
+    // Throw exception if the context has already been started once
+    // or if a stopped context is being started again
+    if (state == Started) {
+      throw new SparkException("StreamingContext has already been started")
+    }
+    if (state == Stopped) {
+      throw new SparkException("StreamingContext has already been stopped")
+    }
     validate()
     scheduler.start()
+    state = Started
   }
 
   /**
@@ -428,14 +446,38 @@ class StreamingContext private[streaming] (
   }
 
   /**
-   * Stop the execution of the streams.
+   * Stop the execution of the streams immediately (does not wait for all received data
+   * to be processed).
    * @param stopSparkContext Stop the associated SparkContext or not
+   *
    */
   def stop(stopSparkContext: Boolean = true): Unit = synchronized {
-    scheduler.stop()
+    stop(stopSparkContext, false)
+  }
+
+  /**
+   * Stop the execution of the streams, with option of ensuring all received data
+   * has been processed.
+   * @param stopSparkContext Stop the associated SparkContext or not
+   * @param stopGracefully Stop gracefully by waiting for the processing of all
+   *                       received data to be completed
+   */
+  def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized {
+    // Warn (but not fail) if context is stopped twice,
+    // or context is stopped before starting
+    if (state == Initialized) {
+      logWarning("StreamingContext has not been started yet")
+      return
+    }
+    if (state == Stopped) {
+      logWarning("StreamingContext has already been stopped")
+      return
+    } // no need to throw an exception as it's okay to stop twice
+    scheduler.stop(stopGracefully)
     logInfo("StreamingContext stopped successfully")
     waiter.notifyStop()
     if (stopSparkContext) sc.stop()
+    state = Stopped
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index b705d2e..c800602 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -509,8 +509,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
    * Stop the execution of the streams.
    * @param stopSparkContext Stop the associated SparkContext or not
    */
-  def stop(stopSparkContext: Boolean): Unit = {
-    ssc.stop(stopSparkContext)
+  def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext)
+
+  /**
+   * Stop the execution of the streams.
+   * @param stopSparkContext Stop the associated SparkContext or not
+   * @param stopGracefully Stop gracefully by waiting for the processing of all
+   *                       received data to be completed
+   */
+  def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = {
+    ssc.stop(stopSparkContext, stopGracefully)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 72ad0ba..d19a635 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.streaming.dstream
 
-import java.util.concurrent.ArrayBlockingQueue
+import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
 import java.nio.ByteBuffer
 
 import scala.collection.mutable.ArrayBuffer
@@ -34,6 +34,7 @@ import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.rdd.{RDD, BlockRDD}
 import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
+import org.apache.spark.util.AkkaUtils
 
 /**
  * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -69,7 +70,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
     // then this returns an empty RDD. This may happen when recovering from a
     // master failure
     if (validTime >= graph.startTime) {
-      val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime)
+      val blockIds = ssc.scheduler.networkInputTracker.getBlocks(id, validTime)
       Some(new BlockRDD[T](ssc.sc, blockIds))
     } else {
       Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
@@ -79,7 +80,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
 
 
 private[streaming] sealed trait NetworkReceiverMessage
-private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
+private[streaming] case class StopReceiver() extends NetworkReceiverMessage
 private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any)
   extends NetworkReceiverMessage
 private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
@@ -90,13 +91,31 @@ private[streaming] case class ReportError(msg: String) extends NetworkReceiverMe
  */
 abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging {
 
+  /** Local SparkEnv */
   lazy protected val env = SparkEnv.get
 
+  /** Remote Akka actor for the NetworkInputTracker */
+  lazy protected val trackerActor = {
+    val ip = env.conf.get("spark.driver.host", "localhost")
+    val port = env.conf.getInt("spark.driver.port", 7077)
+    val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+    env.actorSystem.actorSelection(url)
+  }
+
+  /** Akka actor for receiving messages from the NetworkInputTracker in the driver */
   lazy protected val actor = env.actorSystem.actorOf(
     Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId)
 
+  /** Timeout for Akka actor messages */
+  lazy protected val askTimeout = AkkaUtils.askTimeout(env.conf)
+
+  /** Thread that starts the receiver and stays blocked while data is being received */
   lazy protected val receivingThread = Thread.currentThread()
 
+  /** Exceptions that occur while receiving data */
+  protected lazy val exceptions = new ArrayBuffer[Exception]
+
+  /** Identifier of the stream this receiver is associated with */
   protected var streamId: Int = -1
 
   /**
@@ -112,7 +131,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
   def getLocationPreference() : Option[String] = None
 
   /**
-   * Starts the receiver. First is accesses all the lazy members to
+   * Start the receiver. First it accesses all the lazy members to
    * materialize them. Then it calls the user-defined onStart() method to start
    * other threads, etc. required to receive the data.
    */
@@ -124,83 +143,107 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
       receivingThread
 
       // Call user-defined onStart()
+      logInfo("Starting receiver")
       onStart()
+
+      // Wait until interrupt is called on this thread
+      while(true) Thread.sleep(100000)
     } catch {
       case ie: InterruptedException =>
-        logInfo("Receiving thread interrupted")
+        logInfo("Receiving thread has been interrupted, receiver "  + streamId + " stopped")
       case e: Exception =>
-        stopOnError(e)
+        logError("Error receiving data in receiver " + streamId, e)
+        exceptions += e
+    }
+
+    // Call user-defined onStop()
+    logInfo("Stopping receiver")
+    try {
+      onStop()
+    } catch {
+      case  e: Exception =>
+        logError("Error stopping receiver " + streamId, e)
+        exceptions += e
+    }
+
+    val message = if (exceptions.isEmpty) {
+      null
+    } else if (exceptions.size == 1) {
+      val e = exceptions.head
+      "Exception in receiver " + streamId + ": " + e.getMessage + "\n" + e.getStackTraceString
+    } else {
+      "Multiple exceptions in receiver " + streamId + "(" + exceptions.size + "):\n"
+        exceptions.zipWithIndex.map {
+          case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString
+        }.mkString("\n")
     }
+    logInfo("Deregistering receiver " + streamId)
+    val future = trackerActor.ask(DeregisterReceiver(streamId, message))(askTimeout)
+    Await.result(future, askTimeout)
+    logInfo("Deregistered receiver " + streamId)
+    env.actorSystem.stop(actor)
+    logInfo("Stopped receiver " + streamId)
   }
 
   /**
-   * Stops the receiver. First it interrupts the main receiving thread,
-   * that is, the thread that called receiver.start(). Then it calls the user-defined
-   * onStop() method to stop other threads and/or do cleanup.
+   * Stop the receiver. First it interrupts the main receiving thread,
+   * that is, the thread that called receiver.start().
    */
   def stop() {
+    // Stop receiving by interrupting the receiving thread
     receivingThread.interrupt()
-    onStop()
-    // TODO: terminate the actor
+    logInfo("Interrupted receiving thread " + receivingThread + " for stopping")
   }
 
   /**
-   * Stops the receiver and reports exception to the tracker.
+   * Stop the receiver and report the exception to the tracker.
    * This should be called whenever an exception is to be handled on any thread
    * of the receiver.
    */
   protected def stopOnError(e: Exception) {
     logError("Error receiving data", e)
+    exceptions += e
     stop()
-    actor ! ReportError(e.toString)
   }
 
-
   /**
-   * Pushes a block (as an ArrayBuffer filled with data) into the block manager.
+   * Push a block (as an ArrayBuffer filled with data) into the block manager.
    */
   def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
     env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
-    actor ! ReportBlock(blockId, metadata)
+    trackerActor ! AddBlocks(streamId, Array(blockId), metadata)
+    logDebug("Pushed block " + blockId)
   }
 
   /**
-   * Pushes a block (as bytes) into the block manager.
+   * Push a block (as bytes) into the block manager.
    */
   def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
     env.blockManager.putBytes(blockId, bytes, level)
-    actor ! ReportBlock(blockId, metadata)
+    trackerActor ! AddBlocks(streamId, Array(blockId), metadata)
+  }
+
+  /** Set the ID of the DStream that this receiver is associated with */
+  protected[streaming] def setStreamId(id: Int) {
+    streamId = id
   }
 
   /** A helper actor that communicates with the NetworkInputTracker */
   private class NetworkReceiverActor extends Actor {
-    logInfo("Attempting to register with tracker")
-    val ip = env.conf.get("spark.driver.host", "localhost")
-    val port = env.conf.getInt("spark.driver.port", 7077)
-    val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
-    val tracker = env.actorSystem.actorSelection(url)
-    val timeout = 5.seconds
 
     override def preStart() {
-      val future = tracker.ask(RegisterReceiver(streamId, self))(timeout)
-      Await.result(future, timeout)
+      logInfo("Registered receiver " + streamId)
+      val future = trackerActor.ask(RegisterReceiver(streamId, self))(askTimeout)
+      Await.result(future, askTimeout)
     }
 
     override def receive() = {
-      case ReportBlock(blockId, metadata) =>
-        tracker ! AddBlocks(streamId, Array(blockId), metadata)
-      case ReportError(msg) =>
-        tracker ! DeregisterReceiver(streamId, msg)
-      case StopReceiver(msg) =>
+      case StopReceiver =>
+        logInfo("Received stop signal")
         stop()
-        tracker ! DeregisterReceiver(streamId, msg)
     }
   }
 
-  protected[streaming] def setStreamId(id: Int) {
-    streamId = id
-  }
-
   /**
    * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts
    * them into appropriately named blocks at regular intervals. This class starts two threads,
@@ -214,23 +257,26 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
 
     val clock = new SystemClock()
     val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200)
-    val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+    val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer,
+      "BlockGenerator")
     val blockStorageLevel = storageLevel
     val blocksForPushing = new ArrayBlockingQueue[Block](1000)
     val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
 
     var currentBuffer = new ArrayBuffer[T]
+    var stopped = false
 
     def start() {
       blockIntervalTimer.start()
       blockPushingThread.start()
-      logInfo("Data handler started")
+      logInfo("Started BlockGenerator")
     }
 
     def stop() {
-      blockIntervalTimer.stop()
-      blockPushingThread.interrupt()
-      logInfo("Data handler stopped")
+      blockIntervalTimer.stop(false)
+      stopped = true
+      blockPushingThread.join()
+      logInfo("Stopped BlockGenerator")
     }
 
     def += (obj: T): Unit = synchronized {
@@ -248,24 +294,35 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
         }
       } catch {
         case ie: InterruptedException =>
-          logInfo("Block interval timer thread interrupted")
+          logInfo("Block updating timer thread was interrupted")
         case e: Exception =>
-          NetworkReceiver.this.stop()
+          NetworkReceiver.this.stopOnError(e)
       }
     }
 
     private def keepPushingBlocks() {
-      logInfo("Block pushing thread started")
+      logInfo("Started block pushing thread")
       try {
-        while(true) {
+        while(!stopped) {
+          Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match {
+            case Some(block) =>
+              NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
+            case None =>
+          }
+        }
+        // Push out the blocks that are still left
+        logInfo("Pushing out the last " + blocksForPushing.size() + " blocks")
+        while (!blocksForPushing.isEmpty) {
           val block = blocksForPushing.take()
           NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
+          logInfo("Blocks left to push " + blocksForPushing.size())
         }
+        logInfo("Stopped blocks pushing thread")
       } catch {
         case ie: InterruptedException =>
-          logInfo("Block pushing thread interrupted")
+          logInfo("Block pushing thread was interrupted")
         case e: Exception =>
-          NetworkReceiver.this.stop()
+          NetworkReceiver.this.stopOnError(e)
       }
     }
   }
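
Aside: the stop-and-drain loop in keepPushingBlocks() above is an instance of a general idiom: poll with a timeout while running so the stop flag is noticed promptly, then drain whatever is still queued once stop has been requested. A standalone sketch of the idiom (names are illustrative, not from this commit):

    import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

    val queue = new ArrayBlockingQueue[String](1000)
    @volatile var stopped = false

    def keepDraining(process: String => Unit) {
      while (!stopped) {
        // bounded poll instead of take(), so the loop can observe `stopped`
        Option(queue.poll(100, TimeUnit.MILLISECONDS)).foreach(process)
      }
      // graceful part: push out everything left after stop was requested
      while (!queue.isEmpty) process(queue.take())
    }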

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
index 2cdd13f..63d94d1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
@@ -67,7 +67,6 @@ class SocketReceiver[T: ClassTag](
   protected def onStop() {
     blockGenerator.stop()
   }
-
 }
 
 private[streaming]

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index bd78bae..44eb275 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -174,10 +174,10 @@ private[streaming] class ActorReceiver[T: ClassTag](
     blocksGenerator.start()
     supervisor
     logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
+
   }
 
   protected def onStop() = {
     supervisor ! PoisonPill
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index c730624..92d885c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -39,16 +39,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
   private val ssc = jobScheduler.ssc
   private val graph = ssc.graph
+
   val clock = {
     val clockClass = ssc.sc.conf.get(
       "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
     Class.forName(clockClass).newInstance().asInstanceOf[Clock]
   }
+
   private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
-    longTime => eventActor ! GenerateJobs(new Time(longTime)))
-  private lazy val checkpointWriter =
-    if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
-      new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
+    longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator")
+
+  // This is marked lazy so that it is initialized after checkpoint duration has been set
+  // in the context and the generator has been started.
+  private lazy val shouldCheckpoint = ssc.checkpointDuration != null && ssc.checkpointDir != null
+
+  private lazy val checkpointWriter = if (shouldCheckpoint) {
+    new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
   } else {
     null
   }
@@ -57,17 +63,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   // This not being null means the scheduler has been started and not stopped
   private var eventActor: ActorRef = null
 
+  // last batch whose processing, checkpointing and metadata cleanup have been completed
+  private var lastProcessedBatch: Time = null
+
   /** Start generation of jobs */
-  def start() = synchronized {
-    if (eventActor != null) {
-      throw new SparkException("JobGenerator already started")
-    }
+  def start(): Unit = synchronized {
+    if (eventActor != null) return // generator has already been started
 
     eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
       def receive = {
-        case event: JobGeneratorEvent =>
-          logDebug("Got event of type " + event.getClass.getName)
-          processEvent(event)
+        case event: JobGeneratorEvent =>  processEvent(event)
       }
     }), "JobGenerator")
     if (ssc.isCheckpointPresent) {
@@ -77,30 +82,79 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     }
   }
 
-  /** Stop generation of jobs */
-  def stop() = synchronized {
-    if (eventActor != null) {
-      timer.stop()
-      ssc.env.actorSystem.stop(eventActor)
-      if (checkpointWriter != null) checkpointWriter.stop()
-      ssc.graph.stop()
-      logInfo("JobGenerator stopped")
+  /**
+   * Stop generation of jobs. processReceivedData = true makes this wait until jobs
+   * of the current ongoing time interval have been generated, processed and the
+   * corresponding checkpoints written.
+   */
+  def stop(processReceivedData: Boolean): Unit = synchronized {
+    if (eventActor == null) return // generator has already been stopped
+
+    if (processReceivedData) {
+      logInfo("Stopping JobGenerator gracefully")
+      val timeWhenStopStarted = System.currentTimeMillis()
+      val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds
+      val pollTime = 100
+
+      // To prevent the graceful stop from getting stuck permanently
+      def hasTimedOut = {
+        val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
+        if (timedOut) logWarning("Timed out while stopping the job generator")
+        timedOut
+      }
+
+      // Wait until all the received blocks in the network input tracker have
+      // been consumed by network input DStreams, and jobs have been generated with them
+      logInfo("Waiting for all received blocks to be consumed for job generation")
+      while(!hasTimedOut && jobScheduler.networkInputTracker.hasMoreReceivedBlockIds) {
+        Thread.sleep(pollTime)
+      }
+      logInfo("Waited for all received blocks to be consumed for job generation")
+
+      // Stop generating jobs
+      val stopTime = timer.stop(false)
+      graph.stop()
+      logInfo("Stopped generation timer")
+
+      // Wait for the jobs to complete and checkpoints to be written
+      def haveAllBatchesBeenProcessed = {
+        lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime
+      }
+      logInfo("Waiting for jobs to be processed and checkpoints to be written")
+      while (!hasTimedOut && !haveAllBatchesBeenProcessed) {
+        Thread.sleep(pollTime)
+      }
+      logInfo("Waited for jobs to be processed and checkpoints to be written")
+    } else {
+      logInfo("Stopping JobGenerator immediately")
+      // Stop timer and graph immediately, ignore unprocessed data and pending jobs
+      timer.stop(true)
+      graph.stop()
     }
+
+    // Stop the actor and checkpoint writer
+    if (shouldCheckpoint) checkpointWriter.stop()
+    ssc.env.actorSystem.stop(eventActor)
+    logInfo("Stopped JobGenerator")
   }
 
   /**
-   * On batch completion, clear old metadata and checkpoint computation.
+   * Callback called when a batch has been completely processed.
    */
   def onBatchCompletion(time: Time) {
     eventActor ! ClearMetadata(time)
   }
-  
+
+  /**
+   * Callback called when the checkpoint of a batch has been written.
+   */
   def onCheckpointCompletion(time: Time) {
     eventActor ! ClearCheckpointData(time)
   }
 
   /** Processes all events */
   private def processEvent(event: JobGeneratorEvent) {
+    logDebug("Got event " + event)
     event match {
       case GenerateJobs(time) => generateJobs(time)
       case ClearMetadata(time) => clearMetadata(time)
@@ -114,7 +168,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     val startTime = new Time(timer.getStartTime())
     graph.start(startTime - graph.batchDuration)
     timer.start(startTime.milliseconds)
-    logInfo("JobGenerator started at " + startTime)
+    logInfo("Started JobGenerator at " + startTime)
   }
 
   /** Restarts the generator based on the information in checkpoint */
@@ -152,15 +206,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
     // Restart the timer
     timer.start(restartTime.milliseconds)
-    logInfo("JobGenerator restarted at " + restartTime)
+    logInfo("Restarted JobGenerator at " + restartTime)
   }
 
   /** Generate jobs and perform checkpoint for the given `time`.  */
   private def generateJobs(time: Time) {
     SparkEnv.set(ssc.env)
     Try(graph.generateJobs(time)) match {
-      case Success(jobs) => jobScheduler.runJobs(time, jobs)
-      case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e)
+      case Success(jobs) =>
+        jobScheduler.runJobs(time, jobs)
+      case Failure(e) =>
+        jobScheduler.reportError("Error generating jobs for time " + time, e)
     }
     eventActor ! DoCheckpoint(time)
   }
@@ -168,20 +224,32 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   /** Clear DStream metadata for the given `time`. */
   private def clearMetadata(time: Time) {
     ssc.graph.clearMetadata(time)
-    eventActor ! DoCheckpoint(time)
+
+    // If checkpointing is enabled, then checkpoint,
+    // else mark batch to be fully processed
+    if (shouldCheckpoint) {
+      eventActor ! DoCheckpoint(time)
+    } else {
+      markBatchFullyProcessed(time)
+    }
   }
 
   /** Clear DStream checkpoint data for the given `time`. */
   private def clearCheckpointData(time: Time) {
     ssc.graph.clearCheckpointData(time)
+    markBatchFullyProcessed(time)
   }
 
   /** Perform checkpoint for the given `time`. */
-  private def doCheckpoint(time: Time) = synchronized {
-    if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
+  private def doCheckpoint(time: Time) {
+    if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
       logInfo("Checkpointing graph for time " + time)
       ssc.graph.updateCheckpointData(time)
       checkpointWriter.write(new Checkpoint(ssc, time))
     }
   }
+
+  private def markBatchFullyProcessed(time: Time) {
+    lastProcessedBatch = time
+  }
 }
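
Aside: the graceful path of JobGenerator.stop() above applies the same wait-with-timeout idiom twice (once for unconsumed blocks, once for unprocessed batches). Reduced to a hypothetical helper, not part of this commit:

    def waitUntil(condition: => Boolean, timeoutMs: Long, pollMs: Long = 100): Boolean = {
      val deadline = System.currentTimeMillis() + timeoutMs
      while (!condition && System.currentTimeMillis() < deadline) {
        Thread.sleep(pollMs)  // re-check the condition periodically
      }
      condition  // true if the condition was met, false if we timed out
    }

    // e.g. waitUntil(!jobScheduler.networkInputTracker.hasMoreReceivedBlockIds, stopTimeout)
    //      waitUntil(haveAllBatchesBeenProcessed, stopTimeout)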

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index de675d3..04e0a6a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -39,7 +39,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
 
   private val jobSets = new ConcurrentHashMap[Time, JobSet]
   private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
-  private val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+  private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs)
   private val jobGenerator = new JobGenerator(this)
   val clock = jobGenerator.clock
   val listenerBus = new StreamingListenerBus()
@@ -50,36 +50,54 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
   private var eventActor: ActorRef = null
 
 
-  def start() = synchronized {
-    if (eventActor != null) {
-      throw new SparkException("JobScheduler already started")
-    }
+  def start(): Unit = synchronized {
+    if (eventActor != null) return // scheduler has already been started
 
+    logDebug("Starting JobScheduler")
     eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
       def receive = {
         case event: JobSchedulerEvent => processEvent(event)
       }
     }), "JobScheduler")
+
     listenerBus.start()
     networkInputTracker = new NetworkInputTracker(ssc)
     networkInputTracker.start()
-    Thread.sleep(1000)
     jobGenerator.start()
-    logInfo("JobScheduler started")
+    logInfo("Started JobScheduler")
   }
 
-  def stop() = synchronized {
-    if (eventActor != null) {
-      jobGenerator.stop()
-      networkInputTracker.stop()
-      executor.shutdown()
-      if (!executor.awaitTermination(2, TimeUnit.SECONDS)) {
-        executor.shutdownNow()
-      }
-      listenerBus.stop()
-      ssc.env.actorSystem.stop(eventActor)
-      logInfo("JobScheduler stopped")
+  def stop(processAllReceivedData: Boolean): Unit = synchronized {
+    if (eventActor == null) return // scheduler has already been stopped
+    logDebug("Stopping JobScheduler")
+
+    // First, stop receiving
+    networkInputTracker.stop()
+
+    // Second, stop generating jobs. If it has to process all received data,
+    // then this will wait for all the processing through JobScheduler to be over.
+    jobGenerator.stop(processAllReceivedData)
+
+    // Stop the executor from accepting new jobs
+    logDebug("Stopping job executor")
+    jobExecutor.shutdown()
+
+    // Wait for the queued jobs to complete if indicated
+    val terminated = if (processAllReceivedData) {
+      jobExecutor.awaitTermination(1, TimeUnit.HOURS)  // just a very large period of time
+    } else {
+      jobExecutor.awaitTermination(2, TimeUnit.SECONDS)
     }
+    if (!terminated) {
+      jobExecutor.shutdownNow()
+    }
+    logDebug("Stopped job executor")
+
+    // Stop everything else
+    listenerBus.stop()
+    ssc.env.actorSystem.stop(eventActor)
+    eventActor = null
+    logInfo("Stopped JobScheduler")
   }
 
   def runJobs(time: Time, jobs: Seq[Job]) {
@@ -88,7 +106,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
     } else {
       val jobSet = new JobSet(time, jobs)
       jobSets.put(time, jobSet)
-      jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
+      jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job)))
       logInfo("Added jobs for time " + time)
     }
   }
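
Aside: CheckpointWriter.stop() and JobScheduler.stop() now follow the same two-phase executor shutdown. As a standalone sketch (the helper name is hypothetical):

    import java.util.concurrent.{ExecutorService, TimeUnit}

    def drainAndStop(pool: ExecutorService, graceful: Boolean) {
      pool.shutdown()                      // reject new work, keep running queued jobs
      val terminated =
        if (graceful) pool.awaitTermination(1, TimeUnit.HOURS)  // effectively unbounded
        else pool.awaitTermination(2, TimeUnit.SECONDS)
      if (!terminated) pool.shutdownNow()  // interrupt whatever is still running
    }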

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index cad68e2..067e804 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -17,20 +17,14 @@
 
 package org.apache.spark.streaming.scheduler
 
-import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
-import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
-import org.apache.spark.{SparkException, Logging, SparkEnv}
-import org.apache.spark.SparkContext._
-
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.Queue
-import scala.concurrent.duration._
+import scala.collection.mutable.{HashMap, Queue, SynchronizedMap}
 
 import akka.actor._
-import akka.pattern.ask
-import akka.dispatch._
+import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.SparkContext._
 import org.apache.spark.storage.BlockId
-import org.apache.spark.streaming.{Time, StreamingContext}
+import org.apache.spark.streaming.{StreamingContext, Time}
+import org.apache.spark.streaming.dstream.{NetworkReceiver, StopReceiver}
 import org.apache.spark.util.AkkaUtils
 
 private[streaming] sealed trait NetworkInputTrackerMessage
@@ -52,8 +46,8 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
   val networkInputStreams = ssc.graph.getNetworkInputStreams()
   val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
   val receiverExecutor = new ReceiverExecutor()
-  val receiverInfo = new HashMap[Int, ActorRef]
-  val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
+  val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef]
+  val receivedBlockIds = new HashMap[Int, Queue[BlockId]] with SynchronizedMap[Int, Queue[BlockId]]
   val timeout = AkkaUtils.askTimeout(ssc.conf)
 
 
@@ -63,7 +57,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
   var currentTime: Time = null
 
   /** Start the actor and receiver execution thread. */
-  def start() {
+  def start() = synchronized {
     if (actor != null) {
       throw new SparkException("NetworkInputTracker already started")
     }
@@ -77,72 +71,99 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
   }
 
   /** Stop the receiver execution thread. */
-  def stop() {
+  def stop() = synchronized {
     if (!networkInputStreams.isEmpty && actor != null) {
-      receiverExecutor.interrupt()
-      receiverExecutor.stopReceivers()
+      // First, stop the receivers
+      receiverExecutor.stop()
+
+      // Finally, stop the actor
       ssc.env.actorSystem.stop(actor)
+      actor = null
       logInfo("NetworkInputTracker stopped")
     }
   }
 
-  /** Return all the blocks received from a receiver. */
-  def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
-    val queue =  receivedBlockIds.synchronized {
-      receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
+  /** Register a receiver */
+  def registerReceiver(streamId: Int, receiverActor: ActorRef, sender: ActorRef) {
+    if (!networkInputStreamMap.contains(streamId)) {
+      throw new Exception("Register received for unexpected id " + streamId)
     }
-    val result = queue.synchronized {
-      queue.dequeueAll(x => true)
-    }
-    logInfo("Stream " + receiverId + " received " + result.size + " blocks")
-    result.toArray
+    receiverInfo += ((streamId, receiverActor))
+    logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
+  }
+
+  /** Deregister a receiver */
+  def deregisterReceiver(streamId: Int, message: String) {
+    receiverInfo -= streamId
+    logError("Deregistered receiver for network stream " + streamId + " with message:\n" + message)
+  }
+
+  /** Get all the received blocks for the given stream. */
+  def getBlocks(streamId: Int, time: Time): Array[BlockId] = {
+    val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]())
+    val result = queue.dequeueAll(x => true).toArray
+    logInfo("Stream " + streamId + " received " + result.size + " blocks")
+    result
+  }
+
+  /** Add new blocks for the given stream */
+  def addBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) = {
+    val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId])
+    queue ++= blockIds
+    networkInputStreamMap(streamId).addMetadata(metadata)
+    logDebug("Stream " + streamId + " received new blocks: " + blockIds.mkString("[", ", ", "]"))
+  }
+
+  /** Check if any blocks are left to be processed */
+  def hasMoreReceivedBlockIds: Boolean = {
+    !receivedBlockIds.forall(_._2.isEmpty)
   }
 
   /** Actor to receive messages from the receivers. */
   private class NetworkInputTrackerActor extends Actor {
     def receive = {
-      case RegisterReceiver(streamId, receiverActor) => {
-        if (!networkInputStreamMap.contains(streamId)) {
-          throw new Exception("Register received for unexpected id " + streamId)
-        }
-        receiverInfo += ((streamId, receiverActor))
-        logInfo("Registered receiver for network stream " + streamId + " from "
-          + sender.path.address)
+      case RegisterReceiver(streamId, receiverActor) =>
+        registerReceiver(streamId, receiverActor, sender)
+        sender ! true
+      case AddBlocks(streamId, blockIds, metadata) =>
+        addBlocks(streamId, blockIds, metadata)
+      case DeregisterReceiver(streamId, message) =>
+        deregisterReceiver(streamId, message)
         sender ! true
-      }
-      case AddBlocks(streamId, blockIds, metadata) => {
-        val tmp = receivedBlockIds.synchronized {
-          if (!receivedBlockIds.contains(streamId)) {
-            receivedBlockIds += ((streamId, new Queue[BlockId]))
-          }
-          receivedBlockIds(streamId)
-        }
-        tmp.synchronized {
-          tmp ++= blockIds
-        }
-        networkInputStreamMap(streamId).addMetadata(metadata)
-      }
-      case DeregisterReceiver(streamId, msg) => {
-        receiverInfo -= streamId
-        logError("De-registered receiver for network stream " + streamId
-          + " with message " + msg)
-        // TODO: Do something about the corresponding NetworkInputDStream
-      }
     }
   }
 
   /** This thread class runs all the receivers on the cluster.  */
-  class ReceiverExecutor extends Thread {
-    val env = ssc.env
-
-    override def run() {
-      try {
-        SparkEnv.set(env)
-        startReceivers()
-      } catch {
-        case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
-      } finally {
-        stopReceivers()
+  class ReceiverExecutor {
+    @transient val env = ssc.env
+    @transient val thread  = new Thread() {
+      override def run() {
+        try {
+          SparkEnv.set(env)
+          startReceivers()
+        } catch {
+          case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
+        }
+      }
+    }
+
+    def start() {
+      thread.start()
+    }
+
+    def stop() {
+      // Send the stop signal to all the receivers
+      stopReceivers()
+
+      // Wait for the Spark job that runs the receivers to be over
+      // That is, for the receivers to quit gracefully.
+      thread.join(10000)
+
+      // Check if all the receivers have been deregistered or not
+      if (!receiverInfo.isEmpty) {
+        logWarning("All of the receivers have not deregistered, " + receiverInfo)
+      } else {
+        logInfo("All of the receivers have deregistered successfully")
       }
     }
 
@@ -150,7 +171,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
      * Get the receivers from the NetworkInputDStreams, distributes them to the
      * worker nodes as a parallel collection, and runs them.
      */
-    def startReceivers() {
+    private def startReceivers() {
       val receivers = networkInputStreams.map(nis => {
         val rcvr = nis.getReceiver()
         rcvr.setStreamId(nis.id)
@@ -186,13 +207,16 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
       }
 
       // Distribute the receivers and start them
+      logInfo("Starting " + receivers.length + " receivers")
       ssc.sparkContext.runJob(tempRDD, startReceiver)
+      logInfo("All of the receivers have been terminated")
     }
 
     /** Stops the receivers. */
-    def stopReceivers() {
+    private def stopReceivers() {
       // Signal the receivers to stop
       receiverInfo.values.foreach(_ ! StopReceiver)
+      logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
index c3a849d..c5ef2cc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
@@ -48,14 +48,11 @@ class SystemClock() extends Clock {
         minPollTime 
       }  
     }
-    
-    
+
     while (true) {
       currentTime = System.currentTimeMillis()
       waitTime = targetTime - currentTime
-      
       if (waitTime <= 0) {
-        
         return currentTime
       }
       val sleepTime = 

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index 559c247..f71938a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -17,44 +17,84 @@
 
 package org.apache.spark.streaming.util
 
+import org.apache.spark.Logging
+
 private[streaming]
-class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
+class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String)
+  extends Logging {
   
-  private val thread = new Thread("RecurringTimer") {
+  private val thread = new Thread("RecurringTimer - " + name) {
+    setDaemon(true)
     override def run() { loop }    
   }
-  
-  private var nextTime = 0L
 
+  @volatile private var prevTime = -1L
+  @volatile private var nextTime = -1L
+  @volatile private var stopped = false
+
+  /**
+   * Get the time when this timer will fire if it is started right now.
+   * The time will be a multiple of this timer's period and more than
+   * current system time.
+   */
   def getStartTime(): Long = {
     (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
   }
 
+  /**
+   * Get the time when the timer will fire if it is restarted right now.
+   * This time depends on when the timer was started the first time, and was stopped
+   * for whatever reason. The time must be a multiple of this timer's period and
+   * more than current time.
+   */
   def getRestartTime(originalStartTime: Long): Long = {
     val gap = clock.currentTime - originalStartTime
     (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
   }
 
-  def start(startTime: Long): Long = {
+  /**
+   * Start at the given start time.
+   */
+  def start(startTime: Long): Long = synchronized {
     nextTime = startTime
     thread.start()
+    logInfo("Started timer for " + name + " at time " + nextTime)
     nextTime
   }
 
+  /**
+   * Start at the earliest time it can start based on the period.
+   */
   def start(): Long = {
     start(getStartTime())
   }
 
-  def stop() {
-    thread.interrupt() 
+  /**
+   * Stop the timer, and return the last time the callback was made.
+   * interruptTimer = true will interrupt the callback
+   * if it is in progress (not guaranteed to give correct time in this case).
+   */
+  def stop(interruptTimer: Boolean): Long = synchronized {
+    if (!stopped) {
+      stopped = true
+      if (interruptTimer) thread.interrupt()
+      thread.join()
+      logInfo("Stopped timer for " + name + " after time " + prevTime)
+    }
+    prevTime
   }
-  
+
+  /**
+   * Repeatedly call the callback every interval.
+   */
   private def loop() {
     try {
-      while (true) {
+      while (!stopped) {
         clock.waitTillTime(nextTime)
         callback(nextTime)
+        prevTime = nextTime
         nextTime += period
+        logDebug("Callback for " + name + " called at time " + prevTime)
       }
     } catch {
       case e: InterruptedException =>
@@ -74,10 +114,10 @@ object RecurringTimer {
       println("" + currentTime + ": " + (currentTime - lastRecurTime))
       lastRecurTime = currentTime
     }
-    val timer = new  RecurringTimer(new SystemClock(), period, onRecur)
+    val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test")
     timer.start()
     Thread.sleep(30 * 1000)
-    timer.stop()
+    timer.stop(true)
   }
 }
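
Aside: a short sketch of the new RecurringTimer contract (the class is private[streaming], so this is illustrative rather than user-facing API):

    val timer = new RecurringTimer(new SystemClock(), 1000,
      longTime => println("tick at " + longTime), "Example")
    timer.start()
    Thread.sleep(5500)
    // stop(false) lets an in-flight callback finish and returns the time of the
    // last completed callback; stop(true) interrupts it instead
    val lastTick = timer.stop(interruptTimer = false)
    println("last callback fired at " + lastTick)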
 

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index bcb0c28..bb73dbf 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase {
 
     val updateStateOperation = (s: DStream[String]) => {
       val updateFunc = (values: Seq[Int], state: Option[Int]) => {
-        Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))
+        Some(values.sum + state.getOrElse(0))
       }
       s.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
     }
@@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase {
       // updateFunc clears a state when a StateObject is seen without new values twice in a row
       val updateFunc = (values: Seq[Int], state: Option[StateObject]) => {
         val stateObj = state.getOrElse(new StateObject)
-        values.foldLeft(0)(_ + _) match {
+        values.sum match {
           case 0 => stateObj.expireCounter += 1 // no new values
           case n => { // has new values, increment and reset expireCounter
             stateObj.counter += n

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 717da8e..9cc27ef 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,19 +17,22 @@
 
 package org.apache.spark.streaming
 
-import org.scalatest.{FunSuite, BeforeAndAfter}
-import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver}
+import org.apache.spark.util.{MetadataCleaner, Utils}
+import org.scalatest.{BeforeAndAfter, FunSuite}
 import org.scalatest.concurrent.Timeouts
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkException, SparkConf, SparkContext}
-import org.apache.spark.util.{Utils, MetadataCleaner}
-import org.apache.spark.streaming.dstream.DStream
 
-class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
 
   val master = "local[2]"
   val appName = this.getClass.getSimpleName
-  val batchDuration = Seconds(1)
+  val batchDuration = Milliseconds(500)
   val sparkHome = "someDir"
   val envPair = "key" -> "value"
   val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
@@ -108,19 +111,31 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
     val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
     myConf.set("spark.cleaner.ttl", ttl.toString)
     val ssc1 = new StreamingContext(myConf, batchDuration)
+    addInputStream(ssc1).register
+    ssc1.start()
     val cp = new Checkpoint(ssc1, Time(1000))
     assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
     ssc1.stop()
     val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
     assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
-    ssc = new StreamingContext(null, cp, null)
+    ssc = new StreamingContext(null, newCp, null)
     assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
   }
 
-  test("start multiple times") {
+  test("start and stop state check") {
     ssc = new StreamingContext(master, appName, batchDuration)
     addInputStream(ssc).register
 
+    assert(ssc.state === ssc.StreamingContextState.Initialized)
+    ssc.start()
+    assert(ssc.state === ssc.StreamingContextState.Started)
+    ssc.stop()
+    assert(ssc.state === ssc.StreamingContextState.Stopped)
+  }
+
+  test("start multiple times") {
+    ssc = new StreamingContext(master, appName, batchDuration)
+    addInputStream(ssc).register
     ssc.start()
     intercept[SparkException] {
       ssc.start()
@@ -133,18 +148,61 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
     ssc.start()
     ssc.stop()
     ssc.stop()
-    ssc = null
   }
 
+  test("stop before start and start after stop") {
+    ssc = new StreamingContext(master, appName, batchDuration)
+    addInputStream(ssc).register
+    ssc.stop()  // stop before start should not throw exception
+    ssc.start()
+    ssc.stop()
+    intercept[SparkException] {
+      ssc.start() // start after stop should throw exception
+    }
+  }
+
+
   test("stop only streaming context") {
     ssc = new StreamingContext(master, appName, batchDuration)
     sc = ssc.sparkContext
     addInputStream(ssc).register
     ssc.start()
     ssc.stop(false)
-    ssc = null
     assert(sc.makeRDD(1 to 100).collect().size === 100)
     ssc = new StreamingContext(sc, batchDuration)
+    addInputStream(ssc).register
+    ssc.start()
+    ssc.stop()
+  }
+
+  test("stop gracefully") {
+    val conf = new SparkConf().setMaster(master).setAppName(appName)
+    conf.set("spark.cleaner.ttl", "3600")
+    sc = new SparkContext(conf)
+    for (i <- 1 to 4) {
+      logInfo("==================================")
+      ssc = new StreamingContext(sc, batchDuration)
+      var runningCount = 0
+      TestReceiver.counter.set(1)
+      val input = ssc.networkStream(new TestReceiver)
+      input.count.foreachRDD(rdd => {
+        val count = rdd.first()
+        logInfo("Count = " + count)
+        runningCount += count.toInt
+      })
+      ssc.start()
+      ssc.awaitTermination(500)
+      ssc.stop(stopSparkContext = false, stopGracefully = true)
+      logInfo("Running count = " + runningCount)
+      logInfo("TestReceiver.counter = " + TestReceiver.counter.get())
+      assert(runningCount > 0)
+      assert(
+        (TestReceiver.counter.get() == runningCount + 1) ||
+          (TestReceiver.counter.get() == runningCount + 2),
+        "Received records = " + TestReceiver.counter.get() + ", " +
+          "processed records = " + runningCount
+      )
+    }
   }
 
   test("awaitTermination") {
@@ -199,7 +257,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
   test("awaitTermination with error in job generation") {
     ssc = new StreamingContext(master, appName, batchDuration)
     val inputStream = addInputStream(ssc)
-
     inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
     val exception = intercept[TestException] {
       ssc.start()
@@ -215,4 +272,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
   }
 }
 
-class TestException(msg: String) extends Exception(msg)
\ No newline at end of file
+class TestException(msg: String) extends Exception(msg)
+
+/** Custom receiver for testing whether all data received by a receiver gets processed or not */
+class TestReceiver extends NetworkReceiver[Int] {
+  protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+  protected def onStart() {
+    blockGenerator.start()
+    logInfo("BlockGenerator started on thread " + receivingThread)
+    try {
+      while(true) {
+        blockGenerator += TestReceiver.counter.getAndIncrement
+        Thread.sleep(0)
+      }
+    } finally {
+      logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
+    }
+  }
+
+  protected def onStop() {
+    blockGenerator.stop()
+  }
+}
+
+object TestReceiver {
+  val counter = new AtomicInteger(1)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/83ac9a4b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 2016306..aa2d5c2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
       assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
 
-      Thread.sleep(500) // Give some time for the forgetting old RDDs to complete
+      Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
     } catch {
       case e: Exception => {e.printStackTrace(); throw e}
     } finally {