Posted to commits@spark.apache.org by an...@apache.org on 2016/02/01 20:02:20 UTC

spark git commit: [SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream

Repository: spark
Updated Branches:
  refs/heads/master c1da4d421 -> 6075573a9


[SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream

Add a local property that indicates whether to checkpoint all RDDs that are marked with the checkpoint flag, and enable it in Streaming

Author: Shixiong Zhu <sh...@databricks.com>

Closes #10934 from zsxwing/recursive-checkpoint.
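
For illustration only (not part of the patch), the new property can also be driven directly from user code. The minimal sketch below mirrors the core test added in this commit; the local master, app name, and checkpoint directory are arbitrary choices:

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointAllMarkedAncestorsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("spark-6847-example"))
    sc.setCheckpointDir("/tmp/spark-6847-example")    // hypothetical checkpoint directory

    // Ask Spark to checkpoint every ancestor RDD marked for checkpointing, instead of
    // stopping at the first marked RDD found while walking up from the action's RDD.
    sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")

    val parent = sc.parallelize(1 to 10)
    parent.checkpoint()                               // marked, but only an ancestor
    val child = parent.map(_ + 1)
    child.checkpoint()                                // marked, and closest to the action

    child.count()                                     // triggers doCheckpoint() on the lineage

    // With the property set, both RDDs are materialized to the checkpoint directory;
    // with the default (false), only `child` would be checkpointed.
    println(s"parent checkpointed: ${parent.isCheckpointed}")
    println(s"child checkpointed:  ${child.isCheckpointed}")
    sc.stop()
  }
}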


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6075573a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6075573a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6075573a

Branch: refs/heads/master
Commit: 6075573a93176ee8c071888e4525043d9e73b061
Parents: c1da4d4
Author: Shixiong Zhu <sh...@databricks.com>
Authored: Mon Feb 1 11:02:17 2016 -0800
Committer: Andrew Or <an...@databricks.com>
Committed: Mon Feb 1 11:02:17 2016 -0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 19 ++++++
 .../org/apache/spark/CheckpointSuite.scala      | 21 ++++++
 .../streaming/scheduler/JobGenerator.scala      |  5 ++
 .../streaming/scheduler/JobScheduler.scala      |  7 +-
 .../spark/streaming/CheckpointSuite.scala       | 69 ++++++++++++++++++++
 5 files changed, 119 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6075573a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index be47172..e8157cf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1542,6 +1542,15 @@ abstract class RDD[T: ClassTag](
 
   private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
 
+  // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default,
+  // we stop as soon as we find the first such RDD, an optimization that allows us to write
+  // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both
+  // an RDD and its parent in every batch, in which case the parent may never be checkpointed
+  // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
+  private val checkpointAllMarkedAncestors =
+    Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
+      .map(_.toBoolean).getOrElse(false)
+
   /** Returns the first parent RDD */
   protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
     dependencies.head.rdd.asInstanceOf[RDD[U]]
@@ -1585,6 +1594,13 @@ abstract class RDD[T: ClassTag](
       if (!doCheckpointCalled) {
         doCheckpointCalled = true
         if (checkpointData.isDefined) {
+          if (checkpointAllMarkedAncestors) {
+            // TODO We can collect all the RDDs that need to be checkpointed, and then checkpoint
+            // them in parallel.
+            // Checkpoint parents first because our lineage will be truncated after we
+            // checkpoint ourselves
+            dependencies.foreach(_.rdd.doCheckpoint())
+          }
           checkpointData.get.checkpoint()
         } else {
           dependencies.foreach(_.rdd.doCheckpoint())
@@ -1704,6 +1720,9 @@ abstract class RDD[T: ClassTag](
  */
 object RDD {
 
+  private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS =
+    "spark.checkpoint.checkpointAllMarkedAncestors"
+
   // The following implicit functions were in SparkContext before 1.3 and users had to
   // `import SparkContext._` to enable them. Now we move them here to make the compiler find
   // them automatically. However, we still keep the old functions in SparkContext for backward

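A rough, standalone model of the traversal added above (a toy Node class, not Spark's RDD, and without the doCheckpointCalled guard), showing why parents are visited before the current RDD is checkpointed:

// Toy model of the recursion added above: with the flag on, visit parents first,
// because checkpointing truncates this node's lineage and the original parents
// would no longer be reachable afterwards.
final class Node(val name: String, var parents: List[Node], val marked: Boolean) {
  var checkpointed = false

  def doCheckpoint(checkpointAllMarkedAncestors: Boolean): Unit = {
    if (marked) {
      if (checkpointAllMarkedAncestors) {
        parents.foreach(_.doCheckpoint(checkpointAllMarkedAncestors))
      }
      checkpointed = true
      parents = Nil                        // lineage truncated
    } else {
      // Default behavior: keep walking up until the first marked node is found.
      parents.foreach(_.doCheckpoint(checkpointAllMarkedAncestors))
    }
  }
}

object NodeTraversalDemo {
  def main(args: Array[String]): Unit = {
    val parent = new Node("parent", Nil, marked = true)
    val child  = new Node("child", List(parent), marked = true)
    child.doCheckpoint(checkpointAllMarkedAncestors = true)
    // Prints "true true"; with the flag off, only `child` would be checkpointed.
    println(s"${parent.checkpointed} ${child.checkpointed}")
  }
}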
http://git-wip-us.apache.org/repos/asf/spark/blob/6075573a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 390764b..ce35856 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -512,6 +512,27 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
     assert(rdd.isCheckpointedAndMaterialized === true)
     assert(rdd.partitions.size === 0)
   }
+
+  runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
+    testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
+    testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
+  }
+
+  private def testCheckpointAllMarkedAncestors(
+      reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = {
+    sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString)
+    try {
+      val rdd1 = sc.parallelize(1 to 10)
+      checkpoint(rdd1, reliableCheckpoint)
+      val rdd2 = rdd1.map(_ + 1)
+      checkpoint(rdd2, reliableCheckpoint)
+      rdd2.count()
+      assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors)
+      assert(rdd2.isCheckpointed === true)
+    } finally {
+      sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null)
+    }
+  }
 }
 
 /** RDD partition that has large serialized size. */

http://git-wip-us.apache.org/repos/asf/spark/blob/6075573a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index a5a01e7..a3ad5ea 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler
 import scala.util.{Failure, Success, Try}
 
 import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time}
 import org.apache.spark.streaming.util.RecurringTimer
 import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils}
@@ -243,6 +244,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     // Example: BlockRDDs are created in this thread, and it needs to access BlockManager
     // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
     SparkEnv.set(ssc.env)
+
+    // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
+    // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
+    ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
     Try {
       jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
       graph.generateJobs(time) // generate jobs using allocated block

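Note that the flag is captured when an RDD is constructed (the private val added to RDD.scala above), which is why the property is set here before graph.generateJobs(time) builds the batch's RDDs. A small sketch of that capture-at-construction behavior, not part of the patch, with a local master and a made-up checkpoint directory:

import org.apache.spark.{SparkConf, SparkContext}

object CaptureAtConstructionExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("capture-example"))
    sc.setCheckpointDir("/tmp/capture-example")    // hypothetical checkpoint directory

    sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")
    val parent = sc.parallelize(1 to 10)
    parent.checkpoint()
    val child = parent.map(_ + 1)                  // both RDDs capture "true" at construction
    child.checkpoint()
    sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", null)

    child.count()
    // Prints true: the flag was read when the RDDs were built, not when doCheckpoint() ran.
    println(parent.isCheckpointed)
    sc.stop()
  }
}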
http://git-wip-us.apache.org/repos/asf/spark/blob/6075573a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 9535c8e..3fed3d8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -23,10 +23,10 @@ import scala.collection.JavaConverters._
 import scala.util.Failure
 
 import org.apache.spark.Logging
-import org.apache.spark.rdd.PairRDDFunctions
+import org.apache.spark.rdd.{PairRDDFunctions, RDD}
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.ui.UIUtils
-import org.apache.spark.util.{EventLoop, ThreadUtils, Utils}
+import org.apache.spark.util.{EventLoop, ThreadUtils}
 
 
 private[scheduler] sealed trait JobSchedulerEvent
@@ -210,6 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
           s"""Streaming job from <a href="$batchUrl">$batchLinkText</a>""")
         ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString)
         ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString)
+        // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
+        // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
+        ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
 
         // We need to assign `eventLoop` to a temp variable. Otherwise, because
         // `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then

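The same property is set again here, in the thread that runs each job, in addition to the JobGenerator thread above: local properties are per-thread, so setting the flag in one thread does not make it visible to RDDs constructed in other, already-running threads. A small illustration of that thread-locality (not from the patch):

import org.apache.spark.{SparkConf, SparkContext}

object LocalPropertyThreadLocality {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("local-prop-demo"))

    val worker = new Thread {
      override def run(): Unit =
        sc.setLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors", "true")
    }
    worker.start()
    worker.join()

    // Prints null: the property was set in `worker`'s thread-local copy, so this thread
    // never sees it. That is why the patch sets the flag in both the JobGenerator thread
    // and the job-running thread.
    println(sc.getLocalProperty("spark.checkpoint.checkpointAllMarkedAncestors"))
    sc.stop()
  }
}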
http://git-wip-us.apache.org/repos/asf/spark/blob/6075573a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 4a6b91f..786703e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
     checkpointWriter.stop()
   }
 
+  test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") {
+    // In this test, there are two updateStateByKey operators. The RDD DAG is as follows:
+    //
+    //     batch 1            batch 2            batch 3     ...
+    //
+    // 1) input rdd          input rdd          input rdd
+    //       |                  |                  |
+    //       v                  v                  v
+    // 2) cogroup rdd   ---> cogroup rdd   ---> cogroup rdd  ...
+    //       |         /        |         /        |
+    //       v        /         v        /         v
+    // 3)  map rdd ---        map rdd ---        map rdd     ...
+    //       |                  |                  |
+    //       v                  v                  v
+    // 4) cogroup rdd   ---> cogroup rdd   ---> cogroup rdd  ...
+    //       |         /        |         /        |
+    //       v        /         v        /         v
+    // 5)  map rdd ---        map rdd ---        map rdd     ...
+    //
+    // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to
+    // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second
+    // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first
+    // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3)
+    // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken
+    // and the RDD chain will grow infinitely and cause StackOverflow.
+    //
+    // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarkedAncestors" to force
+    // checkpointing all marked RDDs in the DAG to resolve this issue. (For the previous example,
+    // it will break the connections between layer 2 and layer 3.)
+    ssc = new StreamingContext(master, framework, batchDuration)
+    val batchCounter = new BatchCounter(ssc)
+    ssc.checkpoint(checkpointDir)
+    val inputDStream = new CheckpointInputDStream(ssc)
+    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+      Some(values.sum + state.getOrElse(0))
+    }
+    @volatile var shouldCheckpointAllMarkedRDDs = false
+    @volatile var rddsCheckpointed = false
+    inputDStream.map(i => (i, i))
+      .updateStateByKey(updateFunc).checkpoint(batchDuration)
+      .updateStateByKey(updateFunc).checkpoint(batchDuration)
+      .foreachRDD { rdd =>
+        /**
+         * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors.
+         */
+        def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = {
+          val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList
+          if (rdd.checkpointData.isDefined) {
+            rdd :: markedRDDs
+          } else {
+            markedRDDs
+          }
+        }
+
+        shouldCheckpointAllMarkedRDDs =
+          Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).
+            map(_.toBoolean).getOrElse(false)
+
+        val stateRDDs = findAllMarkedRDDs(rdd)
+        rdd.count()
+        // Check the two state RDDs are both checkpointed
+        rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed)
+      }
+    ssc.start()
+    batchCounter.waitUntilBatchesCompleted(1, 10000)
+    assert(shouldCheckpointAllMarkedRDDs === true)
+    assert(rddsCheckpointed === true)
+  }
+
   /**
    * Advances the manual clock on the streaming scheduler by given number of batches.
    * It also waits for the expected amount of time for each batch.

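For context, a user-level streaming job with the same shape as the DAG described in the test above: two chained updateStateByKey operators, each explicitly checkpointed. The socket source, host, port, and checkpoint directory are arbitrary choices for the sketch. Per the test's comment, before this patch the first state stream's lineage was never truncated, so the RDD chain grew without bound and eventually overflowed the stack:

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object TwoUpdateStateByKeyJob {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("spark-6847-repro")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("/tmp/spark-6847-repro")                // hypothetical checkpoint directory

    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
      Some(values.sum + state.getOrElse(0))
    }

    ssc.socketTextStream("localhost", 9999)                // any input stream works here
      .map(word => (word, 1))
      .updateStateByKey(updateFunc).checkpoint(Seconds(1)) // first state stream  (layer 3)
      .updateStateByKey(updateFunc).checkpoint(Seconds(1)) // second state stream (layer 5)
      .print()

    ssc.start()
    ssc.awaitTermination()
  }
}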

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org