Posted to commits@spark.apache.org by pw...@apache.org on 2014/11/17 21:48:23 UTC

spark git commit: [SPARK-4180] [Core] Prevent creation of multiple active SparkContexts

Repository: spark
Updated Branches:
  refs/heads/master cec1116b4 -> 0f3ceb56c


[SPARK-4180] [Core] Prevent creation of multiple active SparkContexts

This patch adds error-detection logic to throw an exception when attempting to create multiple active SparkContexts in the same JVM, since this is currently unsupported and has been known to cause confusing behavior (see SPARK-2243 for more details).
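
For illustration, here is a minimal REPL-style sketch of the behavior this patch enforces, assuming default settings (i.e. `spark.driver.allowMultipleContexts` is unset):

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

val conf = new SparkConf().setAppName("demo").setMaster("local")
val sc = new SparkContext(conf)

// With this patch, constructing a second active context in the same JVM fails:
try {
  new SparkContext(conf)
} catch {
  case e: SparkException => println(e.getMessage)  // "Only one SparkContext may be running in this JVM ..."
}

// Stopping the active context makes room for a new one.
sc.stop()
val sc2 = new SparkContext(conf)
sc2.stop()
```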

**The solution implemented here is only a partial fix.**  A complete fix would have the following properties:

1. Only one SparkContext may ever be under construction at any given time.
2. Once a SparkContext has been successfully constructed, any subsequent construction attempts should fail until the active SparkContext is stopped.
3. If the SparkContext constructor throws an exception, then all resources created in the constructor should be cleaned up (SPARK-4194).
4. If a user attempts to create a SparkContext but the creation fails, then the user should be able to create new SparkContexts.

This PR only provides 2) and 4); we should be able to provide all of these properties, but the correct fix will involve larger changes to SparkContext's construction / initialization, so we'll target it for a different Spark release.

### The correct solution:

I think that the correct way to do this would be to move the construction of SparkContext's dependencies into a static method in the SparkContext companion object.  Specifically, we could make the default SparkContext constructor `private` and change it to accept a `SparkContextDependencies` object that contains all of SparkContext's dependencies (DAGScheduler, ContextCleaner, etc.).  Secondary constructors could call a method on the SparkContext companion object to create the `SparkContextDependencies` and pass the result to the primary SparkContext constructor.  For example:

```scala
class SparkContext private (deps: SparkContextDependencies) {
  def this(conf: SparkConf) {
    this(SparkContext.getDeps(conf))
  }
}

object SparkContext {
  private[spark] def getDeps(conf: SparkConf): SparkContextDependencies = synchronized {
    if (anotherSparkContextIsActive) { throw new SparkException(...) }
    var dagScheduler: DAGScheduler = null
    try {
      dagScheduler = new DAGScheduler(...)
      [...]
    } catch {
      case e: Exception =>
        // Clean up anything constructed before the failure, then rethrow
        Option(dagScheduler).foreach(_.stop())
        [...]
        throw e
    }
    SparkContextDependencies(dagScheduler, ...)
  }
}
```

This gives us mutual exclusion and ensures that any resources created during a failed SparkContext initialization are properly cleaned up.

This indirection is necessary to maintain binary compatibility.  In retrospect, it would have been nice if SparkContext had no public constructors and could only be created through builder / factory methods on its companion object, since this buys us a lot of flexibility and makes dependency injection easier.
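
For illustration, a factory-only design along those lines might look like the following rough sketch (hypothetical API; `create`, `getDeps`, and `SparkContextDependencies` are illustrative names, not part of Spark):

```scala
// Hypothetical sketch: with no public constructors, every creation path goes
// through the companion object, which can enforce "one active context" and
// assemble (or mock) the dependencies in a single place.
class SparkContext private (deps: SparkContextDependencies) {
  // fields initialized from `deps` ...
}

object SparkContext {
  def create(conf: SparkConf): SparkContext = synchronized {
    val deps = getDeps(conf)  // builds all dependencies, cleaning up on failure
    new SparkContext(deps)
  }
}
```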

### Alternative solutions:

As an alternative solution, we could refactor SparkContext's primary constructor to perform all object creation in a giant `try-finally` block.  Unfortunately, this would require turning a bunch of `vals` into `vars` so that they can be assigned from the `try` block.  If we still want `vals`, we could wrap each `val` in its own `try` block (since a `try` block can return a value), but this leads to extremely messy code and won't guard against future code that doesn't properly handle failures.
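
To show why the per-`val` variant gets messy, here is a rough sketch (hypothetical; `ResourceA` / `ResourceB` stand in for SparkContext's real dependencies such as the listener bus or DAGScheduler):

```scala
// Hypothetical sketch: each val wraps its initializer in a try block so it can
// stay a val, but every later field must clean up everything built before it.
private val resourceA: ResourceA =
  try {
    new ResourceA(conf)
  } catch {
    case e: Exception =>
      // nothing constructed yet, so just rethrow
      throw e
  }

private val resourceB: ResourceB =
  try {
    new ResourceB(resourceA)
  } catch {
    case e: Exception =>
      resourceA.stop()  // easy to forget as new fields are added
      throw e
  }
```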

The more complex approach outlined above gives us some nice dependency injection benefits, so I think that might be preferable to a `var`-ification.

### This PR's solution:

- At the start of the constructor, check whether some other SparkContext is active; if so, throw an exception.
- If another SparkContext might be under construction (or has thrown an exception during construction), allow the new SparkContext to begin construction but log a warning (since resources might have been leaked from a failed creation attempt).
- At the end of the SparkContext constructor, check whether some other SparkContext constructor has raced and successfully created an active context.  If so, throw an exception.

This guarantees that no two SparkContexts will ever be active and exposed to users (since we check at the very end of the constructor).  If two threads race to construct SparkContexts, then one of them will win and the other will throw an exception.
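
In outline, the constructor is bracketed like this (simplified from the diff below; see `markPartiallyConstructed` / `setActiveContext`):

```scala
class SparkContext(config: SparkConf) extends Logging {
  // First statements of the constructor: record that this context is under
  // construction, warning or failing fast if another context is detected.
  private val allowMultipleContexts: Boolean =
    config.getBoolean("spark.driver.allowMultipleContexts", false)
  SparkContext.markPartiallyConstructed(this, allowMultipleContexts)

  // ... the rest of the (unchanged) constructor body ...

  // Last statement of the constructor: re-check for a racing constructor and,
  // if this context won the race, publish it as the active one.
  SparkContext.setActiveContext(this, allowMultipleContexts)
}
```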

This exception can be turned into a warning by setting `spark.driver.allowMultipleContexts = true`.  The check is disabled in Spark's own unit tests, since some suites (such as Hive) would require more significant refactoring to clean up their SparkContexts.  I've made a few changes to other suites' test fixtures to properly clean up SparkContexts so that the unit test logs contain fewer warnings.
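
For example, the check can be relaxed through the driver's SparkConf (the pom.xml and SparkBuild.scala changes below do the equivalent for tests via `-Dspark.driver.allowMultipleContexts=true`):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Downgrades the "multiple SparkContexts" error to a warning.  Intended only
// for tests that cannot easily clean up their contexts.
val conf = new SparkConf()
  .setAppName("test")
  .setMaster("local")
  .set("spark.driver.allowMultipleContexts", "true")
val sc = new SparkContext(conf)
```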

Author: Josh Rosen <jo...@databricks.com>

Closes #3121 from JoshRosen/SPARK-4180 and squashes the following commits:

23c7123 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4180
d38251b [Josh Rosen] Address latest round of feedback.
c0987d3 [Josh Rosen] Accept boolean instead of SparkConf in methods.
85a424a [Josh Rosen] Incorporate more review feedback.
372d0d3 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4180
f5bb78c [Josh Rosen] Update mvn build, too.
d809cb4 [Josh Rosen] Improve handling of failed SparkContext creation attempts.
79a7e6f [Josh Rosen] Fix commented out test
a1cba65 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4180
7ba6db8 [Josh Rosen] Add utility to set system properties in tests.
4629d5c [Josh Rosen] Set spark.driver.allowMultipleContexts=true in tests.
ed17e14 [Josh Rosen] Address review feedback; expose hack workaround for existing unit tests.
1c66070 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4180
06c5c54 [Josh Rosen] Add / improve SparkContext cleanup in streaming BasicOperationsSuite
d0437eb [Josh Rosen] StreamingContext.stop() should stop SparkContext even if StreamingContext has not been started yet.
c4d35a2 [Josh Rosen] Log long form of creation site to aid debugging.
918e878 [Josh Rosen] Document "one SparkContext per JVM" limitation.
afaa7e3 [Josh Rosen] [SPARK-4180] Prevent creations of multiple active SparkContexts.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f3ceb56
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f3ceb56
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f3ceb56

Branch: refs/heads/master
Commit: 0f3ceb56c78e7260725a09fba0e10aa193cbda4b
Parents: cec1116
Author: Josh Rosen <jo...@databricks.com>
Authored: Mon Nov 17 12:48:18 2014 -0800
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Mon Nov 17 12:48:18 2014 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/SparkContext.scala   | 167 ++++++++++++++---
 .../spark/api/java/JavaSparkContext.scala       |   3 +
 .../spark/ExecutorAllocationManagerSuite.scala  |   4 +
 .../org/apache/spark/SparkContextSuite.scala    |  57 +++++-
 docs/programming-guide.md                       |   2 +
 pom.xml                                         |   1 +
 project/SparkBuild.scala                        |   1 +
 .../spark/streaming/BasicOperationsSuite.scala  | 186 +++++++++----------
 .../apache/spark/streaming/TestSuiteBase.scala  |  52 +++++-
 9 files changed, 347 insertions(+), 126 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 65edeef..7cccf74 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -58,12 +58,26 @@ import org.apache.spark.util._
  * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
  * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
  *
+ * Only one SparkContext may be active per JVM.  You must `stop()` the active SparkContext before
+ * creating a new one.  This limitation may eventually be removed; see SPARK-2243 for more details.
+ *
  * @param config a Spark Config object describing the application configuration. Any settings in
  *   this config overrides the default configs as well as system properties.
  */
-
 class SparkContext(config: SparkConf) extends Logging {
 
+  // The call site where this SparkContext was constructed.
+  private val creationSite: CallSite = Utils.getCallSite()
+
+  // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active
+  private val allowMultipleContexts: Boolean =
+    config.getBoolean("spark.driver.allowMultipleContexts", false)
+
+  // In order to prevent multiple SparkContexts from being active at the same time, mark this
+  // context as having started construction.
+  // NOTE: this must be placed at the beginning of the SparkContext constructor.
+  SparkContext.markPartiallyConstructed(this, allowMultipleContexts)
+
   // This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
   // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
   // contains a map from hostname to a list of input format splits on the host.
@@ -1166,27 +1180,30 @@ class SparkContext(config: SparkConf) extends Logging {
 
   /** Shut down the SparkContext. */
   def stop() {
-    postApplicationEnd()
-    ui.foreach(_.stop())
-    // Do this only if not stopped already - best case effort.
-    // prevent NPE if stopped more than once.
-    val dagSchedulerCopy = dagScheduler
-    dagScheduler = null
-    if (dagSchedulerCopy != null) {
-      env.metricsSystem.report()
-      metadataCleaner.cancel()
-      env.actorSystem.stop(heartbeatReceiver)
-      cleaner.foreach(_.stop())
-      dagSchedulerCopy.stop()
-      taskScheduler = null
-      // TODO: Cache.stop()?
-      env.stop()
-      SparkEnv.set(null)
-      listenerBus.stop()
-      eventLogger.foreach(_.stop())
-      logInfo("Successfully stopped SparkContext")
-    } else {
-      logInfo("SparkContext already stopped")
+    SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      postApplicationEnd()
+      ui.foreach(_.stop())
+      // Do this only if not stopped already - best case effort.
+      // prevent NPE if stopped more than once.
+      val dagSchedulerCopy = dagScheduler
+      dagScheduler = null
+      if (dagSchedulerCopy != null) {
+        env.metricsSystem.report()
+        metadataCleaner.cancel()
+        env.actorSystem.stop(heartbeatReceiver)
+        cleaner.foreach(_.stop())
+        dagSchedulerCopy.stop()
+        taskScheduler = null
+        // TODO: Cache.stop()?
+        env.stop()
+        SparkEnv.set(null)
+        listenerBus.stop()
+        eventLogger.foreach(_.stop())
+        logInfo("Successfully stopped SparkContext")
+        SparkContext.clearActiveContext()
+      } else {
+        logInfo("SparkContext already stopped")
+      }
     }
   }
 
@@ -1475,6 +1492,11 @@ class SparkContext(config: SparkConf) extends Logging {
   private[spark] def cleanup(cleanupTime: Long) {
     persistentRdds.clearOldValues(cleanupTime)
   }
+
+  // In order to prevent multiple SparkContexts from being active at the same time, mark this
+  // context as having finished construction.
+  // NOTE: this must be placed at the end of the SparkContext constructor.
+  SparkContext.setActiveContext(this, allowMultipleContexts)
 }
 
 /**
@@ -1483,6 +1505,107 @@ class SparkContext(config: SparkConf) extends Logging {
  */
 object SparkContext extends Logging {
 
+  /**
+   * Lock that guards access to global variables that track SparkContext construction.
+   */
+  private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
+
+  /**
+   * The active, fully-constructed SparkContext.  If no SparkContext is active, then this is `None`.
+   *
+   * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+   */
+  private var activeContext: Option[SparkContext] = None
+
+  /**
+   * Points to a partially-constructed SparkContext if some thread is in the SparkContext
+   * constructor, or `None` if no SparkContext is being constructed.
+   *
+   * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+   */
+  private var contextBeingConstructed: Option[SparkContext] = None
+
+  /**
+   * Called to ensure that no other SparkContext is running in this JVM.
+   *
+   * Throws an exception if a running context is detected and logs a warning if another thread is
+   * constructing a SparkContext.  This warning is necessary because the current locking scheme
+   * prevents us from reliably distinguishing between cases where another context is being
+   * constructed and cases where another constructor threw an exception.
+   */
+  private def assertNoOtherContextIsRunning(
+      sc: SparkContext,
+      allowMultipleContexts: Boolean): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      contextBeingConstructed.foreach { otherContext =>
+        if (otherContext ne sc) {  // checks for reference equality
+          // Since otherContext might point to a partially-constructed context, guard against
+          // its creationSite field being null:
+          val otherContextCreationSite =
+            Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location")
+          val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" +
+            " constructor).  This may indicate an error, since only one SparkContext may be" +
+            " running in this JVM (see SPARK-2243)." +
+            s" The other SparkContext was created at:\n$otherContextCreationSite"
+          logWarning(warnMsg)
+        }
+
+        activeContext.foreach { ctx =>
+          val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
+            " To ignore this error, set spark.driver.allowMultipleContexts = true. " +
+            s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
+          val exception = new SparkException(errMsg)
+          if (allowMultipleContexts) {
+            logWarning("Multiple running SparkContexts detected in the same JVM!", exception)
+          } else {
+            throw exception
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
+   * running.  Throws an exception if a running context is detected and logs a warning if another
+   * thread is constructing a SparkContext.  This warning is necessary because the current locking
+   * scheme prevents us from reliably distinguishing between cases where another context is being
+   * constructed and cases where another constructor threw an exception.
+   */
+  private[spark] def markPartiallyConstructed(
+      sc: SparkContext,
+      allowMultipleContexts: Boolean): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+      contextBeingConstructed = Some(sc)
+    }
+  }
+
+  /**
+   * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
+   * raced with this constructor and started.
+   */
+  private[spark] def setActiveContext(
+      sc: SparkContext,
+      allowMultipleContexts: Boolean): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+      contextBeingConstructed = None
+      activeContext = Some(sc)
+    }
+  }
+
+  /**
+   * Clears the active SparkContext metadata.  This is called by `SparkContext#stop()`.  It's
+   * also called in unit tests to prevent a flood of warnings from test suites that don't / can't
+   * properly clean up their SparkContexts.
+   */
+  private[spark] def clearActiveContext(): Unit = {
+    SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+      activeContext = None
+    }
+  }
+
   private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
 
   private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index d50ed32..6a6d9bf 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -42,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
 /**
  * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
  * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
+ *
+ * Only one SparkContext may be active per JVM.  You must `stop()` the active SparkContext before
+ * creating a new one.  This limitation may eventually be removed; see SPARK-2243 for more details.
  */
 class JavaSparkContext(val sc: SparkContext)
   extends JavaSparkContextVarargsWorkaround with Closeable {

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 4b27477..ce804f9 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -37,20 +37,24 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
       .set("spark.dynamicAllocation.enabled", "true")
     intercept[SparkException] { new SparkContext(conf) }
     SparkEnv.get.stop() // cleanup the created environment
+    SparkContext.clearActiveContext()
 
     // Only min
     val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1")
     intercept[SparkException] { new SparkContext(conf1) }
     SparkEnv.get.stop()
+    SparkContext.clearActiveContext()
 
     // Only max
     val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2")
     intercept[SparkException] { new SparkContext(conf2) }
     SparkEnv.get.stop()
+    SparkContext.clearActiveContext()
 
     // Both min and max, but min > max
     intercept[SparkException] { createSparkContext(2, 1) }
     SparkEnv.get.stop()
+    SparkContext.clearActiveContext()
 
     // Both min and max, and min == max
     val sc1 = createSparkContext(1, 1)

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 31edad1..9e454dd 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -21,9 +21,62 @@ import org.scalatest.FunSuite
 
 import org.apache.hadoop.io.BytesWritable
 
-class SparkContextSuite extends FunSuite {
-  //Regression test for SPARK-3121
+class SparkContextSuite extends FunSuite with LocalSparkContext {
+
+  /** Allows system properties to be changed in tests */
+  private def withSystemProperty[T](property: String, value: String)(block: => T): T = {
+    val originalValue = System.getProperty(property)
+    try {
+      System.setProperty(property, value)
+      block
+    } finally {
+      if (originalValue == null) {
+        System.clearProperty(property)
+      } else {
+        System.setProperty(property, originalValue)
+      }
+    }
+  }
+
+  test("Only one SparkContext may be active at a time") {
+    // Regression test for SPARK-4180
+    withSystemProperty("spark.driver.allowMultipleContexts", "false") {
+      val conf = new SparkConf().setAppName("test").setMaster("local")
+      sc = new SparkContext(conf)
+      // A SparkContext is already running, so we shouldn't be able to create a second one
+      intercept[SparkException] { new SparkContext(conf) }
+      // After stopping the running context, we should be able to create a new one
+      resetSparkContext()
+      sc = new SparkContext(conf)
+    }
+  }
+
+  test("Can still construct a new SparkContext after failing to construct a previous one") {
+    withSystemProperty("spark.driver.allowMultipleContexts", "false") {
+      // This is an invalid configuration (no app name or master URL)
+      intercept[SparkException] {
+        new SparkContext(new SparkConf())
+      }
+      // Even though those earlier calls failed, we should still be able to create a new context
+      sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test"))
+    }
+  }
+
+  test("Check for multiple SparkContexts can be disabled via undocumented debug option") {
+    withSystemProperty("spark.driver.allowMultipleContexts", "true") {
+      var secondSparkContext: SparkContext = null
+      try {
+        val conf = new SparkConf().setAppName("test").setMaster("local")
+        sc = new SparkContext(conf)
+        secondSparkContext = new SparkContext(conf)
+      } finally {
+        Option(secondSparkContext).foreach(_.stop())
+      }
+    }
+  }
+
   test("BytesWritable implicit conversion is correct") {
+    // Regression test for SPARK-3121
     val bytesWritable = new BytesWritable()
     val inputArray = (1 to 10).map(_.toByte).toArray
     bytesWritable.set(inputArray, 0, 10)

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/docs/programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 9de2f91..49f319b 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/
 how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object
 that contains information about your application.
 
+Only one SparkContext may be active per JVM.  You must `stop()` the active SparkContext before creating a new one.
+
 {% highlight scala %}
 val conf = new SparkConf().setAppName(appName).setMaster(master)
 new SparkContext(conf)

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 639ea22..cc7bce1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -978,6 +978,7 @@
               <spark.testing>1</spark.testing>
               <spark.ui.enabled>false</spark.ui.enabled>
               <spark.executor.extraClassPath>${test_classpath}</spark.executor.extraClassPath>
+              <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
             </systemProperties>
           </configuration>
           <executions>

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index c96a6c4..1697b6d 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -377,6 +377,7 @@ object TestSettings {
     javaOptions in Test += "-Dspark.testing=1",
     javaOptions in Test += "-Dspark.port.maxRetries=100",
     javaOptions in Test += "-Dspark.ui.enabled=false",
+    javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
     javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
     javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
       .map { case (k,v) => s"-D$k=$v" }.toSeq,

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 30a3596..86b9678 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -470,32 +470,31 @@ class BasicOperationsSuite extends TestSuiteBase {
   }
 
   test("slice") {
-    val ssc = new StreamingContext(conf, Seconds(1))
-    val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
-    val stream = new TestInputStream[Int](ssc, input, 2)
-    stream.foreachRDD(_ => {})  // Dummy output stream
-    ssc.start()
-    Thread.sleep(2000)
-    def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
-      stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
-    }
+    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+      val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+      val stream = new TestInputStream[Int](ssc, input, 2)
+      stream.foreachRDD(_ => {})  // Dummy output stream
+      ssc.start()
+      Thread.sleep(2000)
+      def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
+        stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
+      }
 
-    assert(getInputFromSlice(0, 1000) == Set(1))
-    assert(getInputFromSlice(0, 2000) == Set(1, 2))
-    assert(getInputFromSlice(1000, 2000) == Set(1, 2))
-    assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
-    ssc.stop()
-    Thread.sleep(1000)
+      assert(getInputFromSlice(0, 1000) == Set(1))
+      assert(getInputFromSlice(0, 2000) == Set(1, 2))
+      assert(getInputFromSlice(1000, 2000) == Set(1, 2))
+      assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
+    }
   }
-
   test("slice - has not been initialized") {
-    val ssc = new StreamingContext(conf, Seconds(1))
-    val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
-    val stream = new TestInputStream[Int](ssc, input, 2)
-    val thrown = intercept[SparkException] {
-      stream.slice(new Time(0), new Time(1000))
+    withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc =>
+      val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+      val stream = new TestInputStream[Int](ssc, input, 2)
+      val thrown = intercept[SparkException] {
+        stream.slice(new Time(0), new Time(1000))
+      }
+      assert(thrown.getMessage.contains("has not been initialized"))
     }
-    assert(thrown.getMessage.contains("has not been initialized"))
   }
 
   val cleanupTestInput = (0 until 10).map(x => Seq(x, x + 1)).toSeq
@@ -555,73 +554,72 @@ class BasicOperationsSuite extends TestSuiteBase {
   test("rdd cleanup - input blocks and persisted RDDs") {
     // Actually receive data over through receiver to create BlockRDDs
 
-    // Start the server
-    val testServer = new TestServer()
-    testServer.start()
-
-    // Set up the streaming context and input streams
-    val ssc = new StreamingContext(conf, batchDuration)
-    val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
-    val mappedStream = networkStream.map(_ + ".").persist()
-    val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
-    val outputStream = new TestOutputStream(mappedStream, outputBuffer)
-
-    outputStream.register()
-    ssc.start()
-
-    // Feed data to the server to send to the network receiver
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    val input = Seq(1, 2, 3, 4, 5, 6)
+    withTestServer(new TestServer()) { testServer =>
+      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+        testServer.start()
+        // Set up the streaming context and input streams
+        val networkStream =
+          ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+        val mappedStream = networkStream.map(_ + ".").persist()
+        val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+        val outputStream = new TestOutputStream(mappedStream, outputBuffer)
+
+        outputStream.register()
+        ssc.start()
+
+        // Feed data to the server to send to the network receiver
+        val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+        val input = Seq(1, 2, 3, 4, 5, 6)
+
+        val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
+        val persistentRddIds = new mutable.HashMap[Time, Int]
+
+        def collectRddInfo() { // get all RDD info required for verification
+          networkStream.generatedRDDs.foreach { case (time, rdd) =>
+            blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
+          }
+          mappedStream.generatedRDDs.foreach { case (time, rdd) =>
+            persistentRddIds(time) = rdd.id
+          }
+        }
 
-    val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
-    val persistentRddIds = new mutable.HashMap[Time, Int]
+        Thread.sleep(200)
+        for (i <- 0 until input.size) {
+          testServer.send(input(i).toString + "\n")
+          Thread.sleep(200)
+          clock.addToTime(batchDuration.milliseconds)
+          collectRddInfo()
+        }
 
-    def collectRddInfo() { // get all RDD info required for verification
-      networkStream.generatedRDDs.foreach { case (time, rdd) =>
-        blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
-      }
-      mappedStream.generatedRDDs.foreach { case (time, rdd) =>
-        persistentRddIds(time) = rdd.id
+        Thread.sleep(200)
+        collectRddInfo()
+        logInfo("Stopping server")
+        testServer.stop()
+
+        // verify data has been received
+        assert(outputBuffer.size > 0)
+        assert(blockRdds.size > 0)
+        assert(persistentRddIds.size > 0)
+
+        import Time._
+
+        val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
+        val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
+        val latestBlockRdd = blockRdds(blockRdds.keySet.max)
+        val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
+        // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
+        assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
+        assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
+
+        // verify that the latest input blocks are present but the earliest blocks have been removed
+        assert(latestBlockRdd.isValid)
+        assert(latestBlockRdd.collect != null)
+        assert(!earliestBlockRdd.isValid)
+        earliestBlockRdd.blockIds.foreach { blockId =>
+          assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
+        }
       }
     }
-
-    Thread.sleep(200)
-    for (i <- 0 until input.size) {
-      testServer.send(input(i).toString + "\n")
-      Thread.sleep(200)
-      clock.addToTime(batchDuration.milliseconds)
-      collectRddInfo()
-    }
-
-    Thread.sleep(200)
-    collectRddInfo()
-    logInfo("Stopping server")
-    testServer.stop()
-    logInfo("Stopping context")
-
-    // verify data has been received
-    assert(outputBuffer.size > 0)
-    assert(blockRdds.size > 0)
-    assert(persistentRddIds.size > 0)
-
-    import Time._
-
-    val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
-    val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
-    val latestBlockRdd = blockRdds(blockRdds.keySet.max)
-    val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
-    // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
-    assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
-    assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
-
-    // verify that the latest input blocks are present but the earliest blocks have been removed
-    assert(latestBlockRdd.isValid)
-    assert(latestBlockRdd.collect != null)
-    assert(!earliestBlockRdd.isValid)
-    earliestBlockRdd.blockIds.foreach { blockId =>
-      assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
-    }
-    ssc.stop()
   }
 
   /** Test cleanup of RDDs in DStream metadata */
@@ -635,13 +633,15 @@ class BasicOperationsSuite extends TestSuiteBase {
     // Setup the stream computation
     assert(batchDuration === Seconds(1),
       "Batch duration has changed from 1 second, check cleanup tests")
-    val ssc = setupStreams(cleanupTestInput, operation)
-    val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
-    if (rememberDuration != null) ssc.remember(rememberDuration)
-    val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    assert(clock.time === Seconds(10).milliseconds)
-    assert(output.size === numExpectedOutput)
-    operatedStream
+    withStreamingContext(setupStreams(cleanupTestInput, operation)) { ssc =>
+      val operatedStream =
+        ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
+      if (rememberDuration != null) ssc.remember(rememberDuration)
+      val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
+      val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+      assert(clock.time === Seconds(10).milliseconds)
+      assert(output.size === numExpectedOutput)
+      operatedStream
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0f3ceb56/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 2154c24..52972f6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -164,6 +164,40 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
   after(afterFunction)
 
   /**
+   * Run a block of code with the given StreamingContext and automatically
+   * stop the context when the block completes or when an exception is thrown.
+   */
+  def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): R = {
+    try {
+      block(ssc)
+    } finally {
+      try {
+        ssc.stop(stopSparkContext = true)
+      } catch {
+        case e: Exception =>
+          logError("Error stopping StreamingContext", e)
+      }
+    }
+  }
+
+  /**
+   * Run a block of code with the given TestServer and automatically
+   * stop the server when the block completes or when an exception is thrown.
+   */
+  def withTestServer[R](testServer: TestServer)(block: TestServer => R): R = {
+    try {
+      block(testServer)
+    } finally {
+      try {
+        testServer.stop()
+      } catch {
+        case e: Exception =>
+          logError("Error stopping TestServer", e)
+      }
+    }
+  }
+
+  /**
    * Set up required DStreams to test the DStream operation using the two sequences
    * of input collections.
    */
@@ -282,10 +316,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
 
       Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
-    } catch {
-      case e: Exception => {e.printStackTrace(); throw e}
     } finally {
-      ssc.stop()
+      ssc.stop(stopSparkContext = true)
     }
     output
   }
@@ -351,9 +383,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       useSet: Boolean
     ) {
     val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
-    val ssc = setupStreams[U, V](input, operation)
-    val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
-    verifyOutput[V](output, expectedOutput, useSet)
+    withStreamingContext(setupStreams[U, V](input, operation)) { ssc =>
+      val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+      verifyOutput[V](output, expectedOutput, useSet)
+    }
   }
 
   /**
@@ -389,8 +422,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
       useSet: Boolean
     ) {
     val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
-    val ssc = setupStreams[U, V, W](input1, input2, operation)
-    val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
-    verifyOutput[W](output, expectedOutput, useSet)
+    withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc =>
+      val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
+      verifyOutput[W](output, expectedOutput, useSet)
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org