Posted to commits@spark.apache.org by td...@apache.org on 2015/05/02 02:46:10 UTC

spark git commit: [SPARK-7112][Streaming][WIP] Add an InputInfoTracker to track all the input streams

Repository: spark
Updated Branches:
  refs/heads/master ebc25a4dd -> b88c275e6


[SPARK-7112][Streaming][WIP] Add an InputInfoTracker to track all the input streams

Author: jerryshao <sa...@intel.com>
Author: Saisai Shao <sa...@intel.com>

Closes #5680 from jerryshao/SPARK-7111 and squashes the following commits:

339f854 [Saisai Shao] Add an end-to-end test
812bcaf [jerryshao] Continue address the comments
abd0036 [jerryshao] Address the comments
727264e [jerryshao] Fix comment typo
6682bef [jerryshao] Fix compile issue
8325787 [jerryshao] Fix rebase issue
17fa251 [jerryshao] Refactor to build InputInfoTracker
ee1b536 [jerryshao] Add DirectStreamTracker to track the direct streams
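
The end-to-end effect of this patch is that per-stream record counts flow from the new InputInfoTracker into JobSet and BatchInfo (see the diffs below), where they surface through the public StreamingListener API. A minimal sketch of observing them; RecordCountListener is an illustrative name, not part of this patch:

    import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

    class RecordCountListener extends StreamingListener {
      override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
        val info = batchCompleted.batchInfo
        // numRecords and streamIdToNumRecords are introduced by this patch
        println(s"Batch ${info.batchTime}: ${info.numRecords} records in total")
        info.streamIdToNumRecords.foreach { case (streamId, numRecords) =>
          println(s"  input stream $streamId -> $numRecords records")
        }
      }
    }

    // Register on the StreamingContext before start():
    // ssc.addStreamingListener(new RecordCountListener)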


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b88c275e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b88c275e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b88c275e

Branch: refs/heads/master
Commit: b88c275e6ef6b17cd34d1c2c780b8959b41222c0
Parents: ebc25a4
Author: jerryshao <sa...@intel.com>
Authored: Fri May 1 17:46:06 2015 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri May 1 17:46:06 2015 -0700

----------------------------------------------------------------------
 .../spark/streaming/StreamingContext.scala      |  4 +-
 .../spark/streaming/dstream/InputDStream.scala  |  3 +
 .../dstream/ReceiverInputDStream.scala          |  9 ++-
 .../spark/streaming/scheduler/BatchInfo.scala   |  8 +-
 .../streaming/scheduler/InputInfoTracker.scala  | 62 +++++++++++++++
 .../streaming/scheduler/JobGenerator.scala      |  8 +-
 .../streaming/scheduler/JobScheduler.scala      |  4 +
 .../spark/streaming/scheduler/JobSet.scala      |  4 +-
 .../apache/spark/streaming/ui/BatchUIData.scala |  2 +-
 .../ui/StreamingJobProgressListener.scala       | 31 ++++----
 .../spark/streaming/ui/StreamingPage.scala      |  4 +-
 .../spark/streaming/InputStreamsSuite.scala     | 33 +++++++-
 .../streaming/StreamingListenerSuite.scala      | 15 ++++
 .../apache/spark/streaming/TestSuiteBase.scala  |  8 +-
 .../scheduler/InputInfoTrackerSuite.scala       | 79 ++++++++++++++++++++
 .../ui/StreamingJobProgressListenerSuite.scala  | 19 ++---
 16 files changed, 247 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 90c8b47..117cb59 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -159,7 +159,7 @@ class StreamingContext private[streaming] (
     }
   }
 
-  private val nextReceiverInputStreamId = new AtomicInteger(0)
+  private val nextInputStreamId = new AtomicInteger(0)
 
   private[streaming] var checkpointDir: String = {
     if (isCheckpointPresent) {
@@ -241,7 +241,7 @@ class StreamingContext private[streaming] (
     if (isCheckpointPresent) cp_ else null
   }
 
-  private[streaming] def getNewReceiverStreamId() = nextReceiverInputStreamId.getAndIncrement()
+  private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement()
 
   /**
    * Create an input stream with any arbitrary user implemented receiver.

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index e652702..e4ad4b5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -41,6 +41,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
 
   ssc.graph.addInputStream(this)
 
+  /** This is a unique identifier for the input stream. */
+  val id = ssc.getNewInputStreamId()
+
   /**
    * Checks whether the 'time' is valid wrt slideDuration for generating RDD.
    * Additionally it also ensures valid times are in strictly increasing order.
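
A quick illustration of the numbering this introduces, assuming a fresh StreamingContext ssc: ids are drawn from a single counter on the context, so streams are numbered 0, 1, 2, ... in construction order (the new InputStreamsSuite test at the end of this commit asserts the same):

    val s1 = ssc.socketTextStream("localhost", 9999)  // first stream registered: s1.id == 0
    val s2 = ssc.socketTextStream("localhost", 9998)  // second stream registered: s2.id == 1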

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index 4c7fd2c..ba88416 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -24,7 +24,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
 import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult}
-import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+import org.apache.spark.streaming.scheduler.{InputInfo, ReceivedBlockInfo}
 
 /**
  * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -39,9 +39,6 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
 abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
   extends InputDStream[T](ssc_) {
 
-  /** This is an unique identifier for the receiver input stream. */
-  val id = ssc.getNewReceiverStreamId()
-
   /**
    * Gets the receiver object that will be sent to the worker nodes
    * to receive data. This method needs to defined by any specific implementation
@@ -72,6 +69,10 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
         val blockStoreResults = blockInfos.map { _.blockStoreResult }
         val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray
 
+        // Register the input blocks' information with the InputInfoTracker
+        val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum)
+        ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
         // Check whether all the results are of the same type
         val resultTypes = blockStoreResults.map { _.getClass }.distinct
         if (resultTypes.size > 1) {
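
Receiver-less input streams can report in the same way from compute(). A hedged sketch mirroring the TestInputStream change later in this commit; FixedBatchInputDStream is an illustrative name, and since InputInfo and inputInfoTracker are private[streaming], this only compiles for classes placed inside the org.apache.spark.streaming package:

    package org.apache.spark.streaming.dstream

    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{StreamingContext, Time}
    import org.apache.spark.streaming.scheduler.InputInfo

    // Emits the same small batch every interval and reports its size.
    class FixedBatchInputDStream(ssc_ : StreamingContext, data: Seq[String])
      extends InputDStream[String](ssc_) {

      override def start(): Unit = { }
      override def stop(): Unit = { }

      override def compute(validTime: Time): Option[RDD[String]] = {
        // Report this batch's record count, just as ReceiverInputDStream does above.
        ssc.scheduler.inputInfoTracker.reportInfo(validTime, InputInfo(id, data.length.toLong))
        Some(ssc.sc.makeRDD(data))
      }
    }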

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
index 92dc113..5b9bfbf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
@@ -24,6 +24,7 @@ import org.apache.spark.streaming.Time
  * :: DeveloperApi ::
  * Class having information on completed batches.
  * @param batchTime   Time of the batch
+ * @param streamIdToNumRecords A map of input stream id to the number of records in the batch
  * @param submissionTime  Clock time of when jobs of this batch was submitted to
  *                        the streaming scheduler queue
  * @param processingStartTime Clock time of when the first job of this batch started processing
@@ -32,7 +33,7 @@ import org.apache.spark.streaming.Time
 @DeveloperApi
 case class BatchInfo(
     batchTime: Time,
-    receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]],
+    streamIdToNumRecords: Map[Int, Long],
     submissionTime: Long,
     processingStartTime: Option[Long],
     processingEndTime: Option[Long]
@@ -58,4 +59,9 @@ case class BatchInfo(
    */
   def totalDelay: Option[Long] = schedulingDelay.zip(processingDelay)
     .map(x => x._1 + x._2).headOption
+
+  /**
+   * The number of records received by the input streams in this batch.
+   */
+  def numRecords: Long = streamIdToNumRecords.values.sum
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
new file mode 100644
index 0000000..a72efcc
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.streaming.{Time, StreamingContext}
+
+/** Tracks the information of an input stream at a specified batch time. */
+private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long)
+
+/**
+ * This class manages all the input streams as well as their input data statistics. The information
+ * will be exposed through StreamingListener for monitoring.
+ */
+private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging {
+
+  // Map to track all the InputInfo related to a specific batch time and input stream.
+  private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]]
+
+  /** Report the input information for the given batch time to the tracker */
+  def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized {
+    val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime,
+      new mutable.HashMap[Int, InputInfo]())
+
+    if (inputInfos.contains(inputInfo.inputStreamId)) {
+      throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch " +
+        s"$batchTime is already added into InputInfoTracker, this is an illegal state")
+    }
+    inputInfos += ((inputInfo.inputStreamId, inputInfo))
+  }
+
+  /** Get all the input streams' information for the specified batch time */
+  def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized {
+    val inputInfos = batchTimeToInputInfos.get(batchTime)
+    // Convert mutable HashMap to immutable Map for the caller
+    inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]())
+  }
+
+  /** Clean up the tracked input information older than the threshold batch time */
+  def cleanup(batchThreshTime: Time): Unit = synchronized {
+    val timesToCleanup = batchTimeToInputInfos.keys.filter(_ < batchThreshTime)
+    logInfo(s"Remove old batch metadata: ${timesToCleanup.mkString(" ")}")
+    batchTimeToInputInfos --= timesToCleanup
+  }
+}
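
Note that cleanup uses a strict '<', so the threshold batch itself is retained; the new InputInfoTrackerSuite below asserts exactly this. A small illustration, assuming an active StreamingContext ssc and batch times in milliseconds:

    val tracker = new InputInfoTracker(ssc)
    tracker.reportInfo(Time(0), InputInfo(0, 100L))
    tracker.reportInfo(Time(1000), InputInfo(0, 200L))

    tracker.cleanup(Time(1000))               // drops batches strictly before 1000 ms
    assert(tracker.getInfo(Time(0)).isEmpty)  // the batch at time 0 is gone
    assert(tracker.getInfo(Time(1000))(0) == InputInfo(0, 200L))  // threshold batch kept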

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 2467d50..9f93d6c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -243,9 +243,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       graph.generateJobs(time) // generate jobs using allocated block
     } match {
       case Success(jobs) =>
-        val receivedBlockInfos =
-          jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray }
-        jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos))
+        val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time)
+        val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords)
+        jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords))
       case Failure(e) =>
         jobScheduler.reportError("Error generating jobs for time " + time, e)
     }
@@ -266,6 +266,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       // checkpointing of this batch to complete.
       val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
       jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
+      jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration)
       markBatchFullyProcessed(time)
     }
   }
@@ -278,6 +279,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     // been saved to checkpoints, so its safe to delete block metadata and data WAL files
     val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
     jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
+    jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration)
     markBatchFullyProcessed(time)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index c7a2c11..1d1ddaa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -50,6 +50,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
   // These two are created only when scheduler starts.
   // eventLoop not being null means the scheduler has been started and not stopped
   var receiverTracker: ReceiverTracker = null
+  // A tracker to track all the input stream information as well as the number of processed records
+  var inputInfoTracker: InputInfoTracker = null
+
   private var eventLoop: EventLoop[JobSchedulerEvent] = null
 
   def start(): Unit = synchronized {
@@ -65,6 +68,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
 
     listenerBus.start(ssc.sparkContext)
     receiverTracker = new ReceiverTracker(ssc)
+    inputInfoTracker = new InputInfoTracker(ssc)
     receiverTracker.start()
     jobGenerator.start()
     logInfo("Started JobScheduler")

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 24b3794..e6be63b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -28,7 +28,7 @@ private[streaming]
 case class JobSet(
     time: Time,
     jobs: Seq[Job],
-    receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) {
+    streamIdToNumRecords: Map[Int, Long] = Map.empty) {
 
   private val incompleteJobs = new HashSet[Job]()
   private val submissionTime = System.currentTimeMillis() // when this jobset was submitted
@@ -64,7 +64,7 @@ case class JobSet(
   def toBatchInfo: BatchInfo = {
     new BatchInfo(
       time,
-      receivedBlockInfo,
+      streamIdToNumRecords,
       submissionTime,
       if (processingStartTime >= 0 ) Some(processingStartTime) else None,
       if (processingEndTime >= 0 ) Some(processingEndTime) else None

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala
index f45c291..99e10d2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala
@@ -66,7 +66,7 @@ private[ui] object BatchUIData {
   def apply(batchInfo: BatchInfo): BatchUIData = {
     new BatchUIData(
       batchInfo.batchTime,
-      batchInfo.receivedBlockInfo.mapValues(_.map(_.numRecords).sum),
+      batchInfo.streamIdToNumRecords,
       batchInfo.submissionTime,
       batchInfo.processingStartTime,
       batchInfo.processingEndTime

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index 34b5571..d2729fa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -188,25 +188,26 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
   }
 
   def receivedRecordsDistributions: Map[Int, Option[Distribution]] = synchronized {
-    val latestBatches = retainedBatches.reverse.take(batchUIDataLimit)
-    (0 until numReceivers).map { receiverId =>
-      val recordsOfParticularReceiver = latestBatches.map { batch =>
-        // calculate records per second for each batch
-        batch.receiverNumRecords.get(receiverId).sum.toDouble * 1000 / batchDuration
-      }
-      val distributionOption = Distribution(recordsOfParticularReceiver)
-      (receiverId, distributionOption)
+    val latestBatchInfos = retainedBatches.reverse.take(batchUIDataLimit)
+    val latestReceiverNumRecords = latestBatchInfos.map(_.receiverNumRecords)
+    val streamIds = ssc.graph.getInputStreams().map(_.id)
+    streamIds.map { id =>
+      val recordsOfParticularReceiver =
+        latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
+      val distribution = Distribution(recordsOfParticularReceiver)
+      (id, distribution)
     }.toMap
   }
 
   def lastReceivedBatchRecords: Map[Int, Long] = synchronized {
-    val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receiverNumRecords)
-    lastReceivedBlockInfoOption.map { lastReceivedBlockInfo =>
-      (0 until numReceivers).map { receiverId =>
-        (receiverId, lastReceivedBlockInfo.getOrElse(receiverId, 0L))
+    val lastReceiverNumRecords = lastReceivedBatch.map(_.receiverNumRecords)
+    val streamIds = ssc.graph.getInputStreams().map(_.id)
+    lastReceiverNumRecords.map { receiverNumRecords =>
+      streamIds.map { id =>
+        (id, receiverNumRecords.getOrElse(id, 0L))
       }.toMap
     }.getOrElse {
-      (0 until numReceivers).map(receiverId => (receiverId, 0L)).toMap
+      streamIds.map(id => (id, 0L)).toMap
     }
   }
 
@@ -214,6 +215,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
     receiverInfos.get(receiverId)
   }
 
+  def receiverIds(): Iterable[Int] = synchronized {
+    receiverInfos.keys
+  }
+
   def lastCompletedBatch: Option[BatchUIData] = synchronized {
     completedBatchUIData.sortBy(_.batchTime)(Time.ordering).lastOption
   }
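
The per-batch rate computed in receivedRecordsDistributions above is records * 1000 / batchDuration, with batchDuration in milliseconds. A worked sketch with illustrative numbers:

    val batchDuration = 2000L  // ms
    val numRecords = 5000L     // records one stream contributed to one batch
    val eventsPerSec = numRecords.toDouble * 1000 / batchDuration  // 2500.0 events/sec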

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
index 07fa285..db37ae8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
@@ -95,7 +95,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
         "Maximum rate\n[events/sec]",
         "Last Error"
       )
-      val dataRows = (0 until listener.numReceivers).map { receiverId =>
+      val dataRows = listener.receiverIds().map { receiverId =>
         val receiverInfo = listener.receiverInfo(receiverId)
         val receiverName = receiverInfo.map(_.name).getOrElse(s"Receiver-$receiverId")
         val receiverActive = receiverInfo.map { info =>
@@ -114,7 +114,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
         }.getOrElse(emptyCell)
         Seq(receiverName, receiverActive, receiverLocation, receiverLastBatchRecords) ++
           receivedRecordStats ++ Seq(receiverLastError)
-      }
+      }.toSeq
       Some(listingTable(headerRow, dataRows))
     } else {
       None

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index e6ac497..eb13675 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -27,17 +27,18 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQu
 import scala.language.postfixOps
 
 import com.google.common.io.Files
+import org.apache.hadoop.io.{Text, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.hadoop.fs.Path
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
 import org.apache.spark.streaming.receiver.Receiver
-import org.apache.spark.rdd.RDD
-import org.apache.hadoop.io.{Text, LongWritable}
-import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
-import org.apache.hadoop.fs.Path
 
 class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
@@ -278,6 +279,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
     }
   }
 
+  test("track the number of input streams") {
+    val ssc = new StreamingContext(conf, batchDuration)
+
+    class TestInputDStream extends InputDStream[String](ssc) {
+      def start() { }
+      def stop() { }
+      def compute(validTime: Time): Option[RDD[String]] = None
+    }
+
+    class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) {
+      def getReceiver: Receiver[String] = null
+    }
+
+    // Register input streams
+    val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream)
+    val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream)
+
+    assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length)
+    assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length)
+    assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams)
+    assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i))
+    assert(receiverInputStreams.map(_.id) === Array(0, 1))
+  }
+
   def testFileStream(newFilesOnly: Boolean) {
     val testDir: File = null
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 9020be1..312cce4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -57,6 +57,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
       info.totalDelay should be (None)
     })
 
+    batchInfosSubmitted.foreach { info =>
+      info.numRecords should be (1L)
+      info.streamIdToNumRecords should be (Map(0 -> 1L))
+    }
+
     isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true)
 
     // SPARK-6766: processingStartTime of batch info should not be None when starting
@@ -70,6 +75,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
       info.totalDelay should be (None)
     })
 
+    batchInfosStarted.foreach { info =>
+      info.numRecords should be (1L)
+      info.streamIdToNumRecords should be (Map(0 -> 1L))
+    }
+
     isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true)
     isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true)
 
@@ -86,6 +96,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
       info.totalDelay.get should be >= 0L
     })
 
+    batchInfosCompleted.foreach { info =>
+      info.numRecords should be (1L)
+      info.streamIdToNumRecords should be (Map(0 -> 1L))
+    }
+
     isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true)
     isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true)
     isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true)

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index c3cae8a..2ba86ae 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -29,10 +29,10 @@ import org.scalatest.time.{Span, Seconds => ScalaTestSeconds}
 import org.scalatest.concurrent.Eventually.timeout
 import org.scalatest.concurrent.PatienceConfiguration
 
-import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
-import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener}
 import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
+import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.{ManualClock, Utils}
 
 /**
@@ -57,6 +57,10 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]],
       return None
     }
 
+    // Report the input data's information to InputInfoTracker for testing
+    val inputInfo = InputInfo(id, selectedInput.length.toLong)
+    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
     val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
     logInfo("Created RDD " + rdd.id + " with " + selectedInput)
     Some(rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
new file mode 100644
index 0000000..5478b41
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Time, Duration, StreamingContext}
+
+class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter {
+
+  private var ssc: StreamingContext = _
+
+  before {
+    val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTracker")
+    if (ssc == null) {
+      ssc = new StreamingContext(conf, Duration(1000))
+    }
+  }
+
+  after {
+    if (ssc != null) {
+      ssc.stop()
+      ssc = null
+    }
+  }
+
+  test("test report and get InputInfo from InputInfoTracker") {
+    val inputInfoTracker = new InputInfoTracker(ssc)
+
+    val streamId1 = 0
+    val streamId2 = 1
+    val time = Time(0L)
+    val inputInfo1 = InputInfo(streamId1, 100L)
+    val inputInfo2 = InputInfo(streamId2, 300L)
+    inputInfoTracker.reportInfo(time, inputInfo1)
+    inputInfoTracker.reportInfo(time, inputInfo2)
+
+    val batchTimeToInputInfos = inputInfoTracker.getInfo(time)
+    assert(batchTimeToInputInfos.size == 2)
+    assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2))
+    assert(batchTimeToInputInfos(streamId1) === inputInfo1)
+    assert(batchTimeToInputInfos(streamId2) === inputInfo2)
+    assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1)
+  }
+
+  test("test cleanup InputInfo from InputInfoTracker") {
+    val inputInfoTracker = new InputInfoTracker(ssc)
+
+    val streamId1 = 0
+    val inputInfo1 = InputInfo(streamId1, 100L)
+    val inputInfo2 = InputInfo(streamId1, 300L)
+    inputInfoTracker.reportInfo(Time(0), inputInfo1)
+    inputInfoTracker.reportInfo(Time(1), inputInfo2)
+
+    inputInfoTracker.cleanup(Time(0))
+    assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1)
+    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)
+
+    inputInfoTracker.cleanup(Time(1))
+    assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None)
+    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b88c275e/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index fa89536..e874536 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -49,13 +49,10 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
     val ssc = setupStreams(input, operation)
     val listener = new StreamingJobProgressListener(ssc)
 
-    val receivedBlockInfo = Map(
-      0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)),
-      1 -> Array(ReceivedBlockInfo(1, 300, null))
-    )
+    val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L)
 
     // onBatchSubmitted
-    val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None)
+    val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None)
     listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted))
     listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted)))
     listener.runningBatches should be (Nil)
@@ -67,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
     listener.numTotalReceivedRecords should be (0)
 
     // onBatchStarted
-    val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None)
+    val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None)
     listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted))
     listener.waitingBatches should be (Nil)
     listener.runningBatches should be (List(BatchUIData(batchInfoStarted)))
@@ -106,7 +103,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
         OutputOpIdAndSparkJobId(1, 1))
 
     // onBatchCompleted
-    val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None)
+    val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None)
     listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted))
     listener.waitingBatches should be (Nil)
     listener.runningBatches should be (Nil)
@@ -144,11 +141,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
     val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100)
     val listener = new StreamingJobProgressListener(ssc)
 
-    val receivedBlockInfo = Map(
-      0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)),
-      1 -> Array(ReceivedBlockInfo(1, 300, null))
-    )
-    val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None)
+    val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L)
+
+    val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None)
 
     for(_ <- 0 until (limit + 10)) {
       listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted))

