You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2015/04/23 20:29:34 UTC

spark git commit: [SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext

Repository: spark
Updated Branches:
  refs/heads/master cc48e6387 -> 534f2a436


[SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext

Currently, if you want to create a StreamingContext from checkpoint information, the system will create a new SparkContext. This prevents a StreamingContext from being recreated from checkpoints in managed environments where the SparkContext is precreated.

The solution in this PR: Introduce the following methods on StreamingContext
1. `new StreamingContext(checkpointDirectory, sparkContext)`
   Recreate StreamingContext from checkpoint using the provided SparkContext
2. `StreamingContext.getOrCreate(checkpointDirectory, createFunction: SparkContext => StreamingContext, sparkContext)`
   If checkpoint files exist, then recreate the StreamingContext using the provided SparkContext (that is, call 1.); otherwise, create the StreamingContext by applying the provided createFunction to the SparkContext.

TODO: the corresponding Python API has to be added as well (the Java API is included in this commit).

Author: Tathagata Das <ta...@gmail.com>

Closes #5428 from tdas/SPARK-6752 and squashes the following commits:

94db63c [Tathagata Das] Fix long line.
524f519 [Tathagata Das] Many changes based on PR comments.
eabd092 [Tathagata Das] Added Function0, Java API and unit tests for StreamingContext.getOrCreate
36a7823 [Tathagata Das] Minor changes.
204814e [Tathagata Das] Added StreamingContext.getOrCreate with existing SparkContext


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/534f2a43
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/534f2a43
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/534f2a43

Branch: refs/heads/master
Commit: 534f2a43625fbf1a3a65d09550a19875cd1dce43
Parents: cc48e63
Author: Tathagata Das <ta...@gmail.com>
Authored: Thu Apr 23 11:29:34 2015 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Thu Apr 23 11:29:34 2015 -0700

----------------------------------------------------------------------
 .../spark/api/java/function/Function0.java      |  27 ++++
 .../org/apache/spark/streaming/Checkpoint.scala |  26 ++-
 .../spark/streaming/StreamingContext.scala      |  85 ++++++++--
 .../api/java/JavaStreamingContext.scala         | 119 +++++++++++++-
 .../apache/spark/streaming/JavaAPISuite.java    | 145 ++++++++++++-----
 .../spark/streaming/CheckpointSuite.scala       |   3 +-
 .../spark/streaming/StreamingContextSuite.scala | 159 +++++++++++++++++++
 7 files changed, 503 insertions(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/core/src/main/java/org/apache/spark/api/java/function/Function0.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java
new file mode 100644
index 0000000..38e410c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A zero-argument function that returns an R.
+ */
+public interface Function0<R> extends Serializable {
+  public R call() throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 0a50485..7bfae25 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -77,7 +77,8 @@ object Checkpoint extends Logging {
   }
 
   /** Get checkpoint files present in the give directory, ordered by oldest-first */
-  def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = {
+  def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = {
+
     def sortFunc(path1: Path, path2: Path): Boolean = {
       val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
       val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
@@ -85,6 +86,7 @@ object Checkpoint extends Logging {
     }
 
     val path = new Path(checkpointDir)
+    val fs = fsOption.getOrElse(path.getFileSystem(new Configuration()))
     if (fs.exists(path)) {
       val statuses = fs.listStatus(path)
       if (statuses != null) {
@@ -160,7 +162,7 @@ class CheckpointWriter(
           }
 
           // Delete old checkpoint files
-          val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs)
+          val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs))
           if (allCheckpointFiles.size > 10) {
             allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
               logInfo("Deleting " + file)
@@ -234,15 +236,24 @@ class CheckpointWriter(
 private[streaming]
 object CheckpointReader extends Logging {
 
-  def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] =
-  {
+  /**
+   * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint
+   * files, then return None, else try to return the latest valid checkpoint object. If no
+   * checkpoint files could be read correctly, then return None (if ignoreReadError = true),
+   * or throw an exception (if ignoreReadError = false).
+   */
+  def read(
+      checkpointDir: String,
+      conf: SparkConf,
+      hadoopConf: Configuration,
+      ignoreReadError: Boolean = false): Option[Checkpoint] = {
     val checkpointPath = new Path(checkpointDir)
 
     // TODO(rxin): Why is this a def?!
     def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf)
 
     // Try to find the checkpoint files
-    val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse
+    val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse
     if (checkpointFiles.isEmpty) {
       return None
     }
@@ -282,7 +293,10 @@ object CheckpointReader extends Logging {
     })
 
     // If none of checkpoint files could be read, then throw exception
-    throw new SparkException("Failed to read checkpoint from directory " + checkpointPath)
+    if (!ignoreReadError) {
+      throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath")
+    }
+    None
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index f57f295..90c8b47 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -107,6 +107,19 @@ class StreamingContext private[streaming] (
    */
   def this(path: String) = this(path, new Configuration)
 
+  /**
+   * Recreate a StreamingContext from a checkpoint file using an existing SparkContext.
+   * @param path Path to the directory that was specified as the checkpoint directory
+   * @param sparkContext Existing SparkContext
+   */
+  def this(path: String, sparkContext: SparkContext) = {
+    this(
+      sparkContext,
+      CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get,
+      null)
+  }
+
+
   if (sc_ == null && cp_ == null) {
     throw new Exception("Spark Streaming cannot be initialized with " +
       "both SparkContext and checkpoint as null")
@@ -115,10 +128,12 @@ class StreamingContext private[streaming] (
   private[streaming] val isCheckpointPresent = (cp_ != null)
 
   private[streaming] val sc: SparkContext = {
-    if (isCheckpointPresent) {
+    if (sc_ != null) {
+      sc_
+    } else if (isCheckpointPresent) {
       new SparkContext(cp_.createSparkConf())
     } else {
-      sc_
+      throw new SparkException("Cannot create StreamingContext without a SparkContext")
     }
   }
 
@@ -129,7 +144,7 @@ class StreamingContext private[streaming] (
 
   private[streaming] val conf = sc.conf
 
-  private[streaming] val env = SparkEnv.get
+  private[streaming] val env = sc.env
 
   private[streaming] val graph: DStreamGraph = {
     if (isCheckpointPresent) {
@@ -174,7 +189,9 @@ class StreamingContext private[streaming] (
 
   /** Register streaming source to metrics system */
   private val streamingSource = new StreamingSource(this)
-  SparkEnv.get.metricsSystem.registerSource(streamingSource)
+  assert(env != null)
+  assert(env.metricsSystem != null)
+  env.metricsSystem.registerSource(streamingSource)
 
   /** Enumeration to identify current state of the StreamingContext */
   private[streaming] object StreamingContextState extends Enumeration {
@@ -621,19 +638,59 @@ object StreamingContext extends Logging {
       hadoopConf: Configuration = new Configuration(),
       createOnError: Boolean = false
     ): StreamingContext = {
-    val checkpointOption = try {
-      CheckpointReader.read(checkpointPath,  new SparkConf(), hadoopConf)
-    } catch {
-      case e: Exception =>
-        if (createOnError) {
-          None
-        } else {
-          throw e
-        }
-    }
+    val checkpointOption = CheckpointReader.read(
+      checkpointPath, new SparkConf(), hadoopConf, createOnError)
     checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc())
   }
 
+
+  /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+   * will be created by calling the provided `creatingFunc` on the provided `sparkContext`. Note
+   * that the SparkConf configuration in the checkpoint data will not be restored as the
+   * SparkContext has already been created.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+   * @param creatingFunc   Function to create a new StreamingContext using the given SparkContext
+   * @param sparkContext   SparkContext using which the StreamingContext will be created
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: SparkContext => StreamingContext,
+      sparkContext: SparkContext
+    ): StreamingContext = {
+    getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false)
+  }
+
+  /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+   * will be created by calling the provided `creatingFunc` on the provided `sparkContext`. Note
+   * that the SparkConf configuration in the checkpoint data will not be restored as the
+   * SparkContext has already been created.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+   * @param creatingFunc   Function to create a new StreamingContext using the given SparkContext
+   * @param sparkContext   SparkContext using which the StreamingContext will be created
+   * @param createOnError  Whether to create a new StreamingContext if there is an
+   *                       error in reading checkpoint data. By default, an exception will be
+   *                       thrown on error.
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: SparkContext => StreamingContext,
+      sparkContext: SparkContext,
+      createOnError: Boolean
+    ): StreamingContext = {
+    val checkpointOption = CheckpointReader.read(
+      checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError)
+    checkpointOption.map(new StreamingContext(sparkContext, _, null))
+                    .getOrElse(creatingFunc(sparkContext))
+  }
+
   /**
    * Find the JAR from which a given class was loaded, to make it easy for users to pass
    * their JARs to StreamingContext.

http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 4095a7c..572d7d8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -32,13 +32,14 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function0 => JFunction0}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.hadoop.conf.Configuration
-import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream}
+import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.streaming.receiver.Receiver
+import org.apache.hadoop.conf.Configuration
 
 /**
  * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main
@@ -655,6 +656,7 @@ object JavaStreamingContext {
    * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
    * @param factory        JavaStreamingContextFactory object to create a new JavaStreamingContext
    */
+  @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0")
   def getOrCreate(
       checkpointPath: String,
       factory: JavaStreamingContextFactory
@@ -676,6 +678,7 @@ object JavaStreamingContext {
    * @param hadoopConf     Hadoop configuration if necessary for reading from any HDFS compatible
    *                       file system
    */
+  @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
   def getOrCreate(
       checkpointPath: String,
       hadoopConf: Configuration,
@@ -700,6 +703,7 @@ object JavaStreamingContext {
    * @param createOnError  Whether to create a new JavaStreamingContext if there is an
    *                       error in reading checkpoint data.
    */
+  @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0")
   def getOrCreate(
       checkpointPath: String,
       hadoopConf: Configuration,
@@ -713,6 +717,117 @@ object JavaStreamingContext {
   }
 
   /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the provided
+   * `creatingFunc` will be used to create a JavaStreamingContext.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+   * @param creatingFunc   Function to create a new JavaStreamingContext
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: JFunction0[JavaStreamingContext]
+    ): JavaStreamingContext = {
+    val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+      creatingFunc.call().ssc
+    })
+    new JavaStreamingContext(ssc)
+  }
+
+  /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the provided
+   * `creatingFunc` will be used to create a JavaStreamingContext.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+   * @param creatingFunc   Function to create a new JavaStreamingContext
+   * @param hadoopConf     Hadoop configuration if necessary for reading from any HDFS compatible
+   *                       file system
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: JFunction0[JavaStreamingContext],
+      hadoopConf: Configuration
+    ): JavaStreamingContext = {
+    val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+      creatingFunc.call().ssc
+    }, hadoopConf)
+    new JavaStreamingContext(ssc)
+  }
+
+  /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the provided
+   * `creatingFunc` will be used to create a JavaStreamingContext.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+   * @param creatingFunc   Function to create a new JavaStreamingContext
+   * @param hadoopConf     Hadoop configuration if necessary for reading from any HDFS compatible
+   *                       file system
+   * @param createOnError  Whether to create a new JavaStreamingContext if there is an
+   *                       error in reading checkpoint data.
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: JFunction0[JavaStreamingContext],
+      hadoopConf: Configuration,
+      createOnError: Boolean
+    ): JavaStreamingContext = {
+    val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+      creatingFunc.call().ssc
+    }, hadoopConf, createOnError)
+    new JavaStreamingContext(ssc)
+  }
+
+  /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the provided factory
+   * will be used to create a JavaStreamingContext.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+   * @param creatingFunc   Function to create a new JavaStreamingContext
+   * @param sparkContext   SparkContext using which the StreamingContext will be created
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext],
+      sparkContext: JavaSparkContext
+    ): JavaStreamingContext = {
+    val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => {
+      creatingFunc.call(new JavaSparkContext(sparkContext)).ssc
+    }, sparkContext.sc)
+    new JavaStreamingContext(ssc)
+  }
+
+  /**
+   * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+   * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+   * recreated from the checkpoint data. If the data does not exist, then the provided
+   * `creatingFunc` will be applied to `sparkContext` to create a JavaStreamingContext.
+   *
+   * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+   * @param creatingFunc   Function to create a new JavaStreamingContext
+   * @param sparkContext   SparkContext using which the StreamingContext will be created
+   * @param createOnError  Whether to create a new JavaStreamingContext if there is an
+   *                       error in reading checkpoint data.
+   */
+  def getOrCreate(
+      checkpointPath: String,
+      creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext],
+      sparkContext: JavaSparkContext,
+      createOnError: Boolean
+    ): JavaStreamingContext = {
+    val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => {
+      creatingFunc.call(new JavaSparkContext(sparkContext)).ssc
+    }, sparkContext.sc, createOnError)
+    new JavaStreamingContext(ssc)
+  }
+
+  /**
    * Find the JAR from which a given class was loaded, to make it easy for users to pass
    * their JARs to StreamingContext.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 9034075..cb2e838 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -22,10 +22,12 @@ import java.lang.Iterable;
 import java.nio.charset.Charset;
 import java.util.*;
 
+import org.apache.commons.lang.mutable.MutableBoolean;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+
 import scala.Tuple2;
 
 import org.junit.Assert;
@@ -45,6 +47,7 @@ import org.apache.spark.api.java.function.*;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.streaming.api.java.*;
 import org.apache.spark.util.Utils;
+import org.apache.spark.SparkConf;
 
 // The test suite itself is Serializable so that anonymous Function implementations can be
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
@@ -929,7 +932,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
           public Tuple2<Integer, String> call(Tuple2<String, Integer> in) throws Exception {
             return in.swap();
           }
-    });
+        });
 
     JavaTestUtils.attachTestOutputStream(reversed);
     List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -987,12 +990,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
     JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
     JavaDStream<Integer> reversed = pairStream.map(
-            new Function<Tuple2<String, Integer>, Integer>() {
-              @Override
-              public Integer call(Tuple2<String, Integer> in) throws Exception {
-                return in._2();
-              }
-            });
+        new Function<Tuple2<String, Integer>, Integer>() {
+          @Override
+          public Integer call(Tuple2<String, Integer> in) throws Exception {
+            return in._2();
+          }
+        });
 
     JavaTestUtils.attachTestOutputStream(reversed);
     List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -1123,7 +1126,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
 
     JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey(
         new Function<Integer, Integer>() {
-        @Override
+          @Override
           public Integer call(Integer i) throws Exception {
             return i;
           }
@@ -1144,14 +1147,14 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
       Arrays.asList("hello"));
 
     List<List<Tuple2<String, Long>>> expected = Arrays.asList(
-      Arrays.asList(
-              new Tuple2<String, Long>("hello", 1L),
-              new Tuple2<String, Long>("world", 1L)),
-      Arrays.asList(
-              new Tuple2<String, Long>("hello", 1L),
-              new Tuple2<String, Long>("moon", 1L)),
-      Arrays.asList(
-              new Tuple2<String, Long>("hello", 1L)));
+        Arrays.asList(
+            new Tuple2<String, Long>("hello", 1L),
+            new Tuple2<String, Long>("world", 1L)),
+        Arrays.asList(
+            new Tuple2<String, Long>("hello", 1L),
+            new Tuple2<String, Long>("moon", 1L)),
+        Arrays.asList(
+            new Tuple2<String, Long>("hello", 1L)));
 
     JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
     JavaPairDStream<String, Long> counted = stream.countByValue();
@@ -1249,17 +1252,17 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
 
     JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
         new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
-        @Override
-        public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
-          int out = 0;
-          if (state.isPresent()) {
-            out = out + state.get();
-          }
-          for (Integer v: values) {
-            out = out + v;
+          @Override
+          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+            int out = 0;
+            if (state.isPresent()) {
+              out = out + state.get();
+            }
+            for (Integer v : values) {
+              out = out + v;
+            }
+            return Optional.of(out);
           }
-          return Optional.of(out);
-        }
         });
     JavaTestUtils.attachTestOutputStream(updated);
     List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -1292,17 +1295,17 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
 
     JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
         new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
-        @Override
-        public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
-          int out = 0;
-          if (state.isPresent()) {
-            out = out + state.get();
-          }
-          for (Integer v: values) {
-            out = out + v;
+          @Override
+          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+            int out = 0;
+            if (state.isPresent()) {
+              out = out + state.get();
+            }
+            for (Integer v : values) {
+              out = out + v;
+            }
+            return Optional.of(out);
           }
-          return Optional.of(out);
-        }
         }, new HashPartitioner(1), initialRDD);
     JavaTestUtils.attachTestOutputStream(updated);
     List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -1328,7 +1331,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
 
     JavaPairDStream<String, Integer> reduceWindowed =
         pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(),
-          new Duration(2000), new Duration(1000));
+            new Duration(2000), new Duration(1000));
     JavaTestUtils.attachTestOutputStream(reduceWindowed);
     List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
 
@@ -1707,6 +1710,74 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     Utils.deleteRecursively(tempDir);
   }
 
+  @SuppressWarnings("unchecked")
+  @Test
+  public void testContextGetOrCreate() throws InterruptedException {
+
+    final SparkConf conf = new SparkConf()
+        .setMaster("local[2]")
+        .setAppName("test")
+        .set("newContext", "true");
+
+    File emptyDir = Files.createTempDir();
+    emptyDir.deleteOnExit();
+    StreamingContextSuite contextSuite = new StreamingContextSuite();
+    String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint();
+    String checkpointDir = contextSuite.createValidCheckpoint();
+
+    // Function to create JavaStreamingContext without any output operations
+    // (used to detect the new context)
+    final MutableBoolean newContextCreated = new MutableBoolean(false);
+    Function0<JavaStreamingContext> creatingFunc = new Function0<JavaStreamingContext>() {
+      public JavaStreamingContext call() {
+        newContextCreated.setValue(true);
+        return new JavaStreamingContext(conf, Seconds.apply(1));
+      }
+    };
+
+    newContextCreated.setValue(false);
+    ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc);
+    Assert.assertTrue("new context not created", newContextCreated.isTrue());
+    ssc.stop();
+
+    newContextCreated.setValue(false);
+    ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc,
+        new org.apache.hadoop.conf.Configuration(), true);
+    Assert.assertTrue("new context not created", newContextCreated.isTrue());
+    ssc.stop();
+
+    newContextCreated.setValue(false);
+    ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc,
+        new org.apache.hadoop.conf.Configuration());
+    Assert.assertTrue("old context not recovered", newContextCreated.isFalse());
+    ssc.stop();
+
+    // Function to create JavaStreamingContext using existing JavaSparkContext
+    // without any output operations (used to detect the new context)
+    Function<JavaSparkContext, JavaStreamingContext> creatingFunc2 =
+        new Function<JavaSparkContext, JavaStreamingContext>() {
+          public JavaStreamingContext call(JavaSparkContext context) {
+            newContextCreated.setValue(true);
+            return new JavaStreamingContext(context, Seconds.apply(1));
+          }
+        };
+
+    JavaSparkContext sc = new JavaSparkContext(conf);
+    newContextCreated.setValue(false);
+    ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc);
+    Assert.assertTrue("new context not created", newContextCreated.isTrue());
+    ssc.stop(false);
+
+    newContextCreated.setValue(false);
+    ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true);
+    Assert.assertTrue("new context not created", newContextCreated.isTrue());
+    ssc.stop(false);
+
+    newContextCreated.setValue(false);
+    ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc);
+    Assert.assertTrue("old context not recovered", newContextCreated.isFalse());
+    ssc.stop();
+  }
 
   /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD
   @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 54c3044..6b0a3f9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -430,9 +430,8 @@ class CheckpointSuite extends TestSuiteBase {
           assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3)
         }
         // Wait for a checkpoint to be written
-        val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration)
         eventually(eventuallyTimeout) {
-          assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6)
+          assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6)
         }
         ssc.stop()
         // Check that we shut down while the third batch was being processed

http://git-wip-us.apache.org/repos/asf/spark/blob/534f2a43/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 58353a5..4f19332 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.streaming
 
+import java.io.File
 import java.util.concurrent.atomic.AtomicInteger
 
+import org.apache.commons.io.FileUtils
 import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.concurrent.Eventually._
@@ -330,6 +332,139 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     }
   }
 
+  test("getOrCreate") {
+    val conf = new SparkConf().setMaster(master).setAppName(appName)
+
+    // Creates a new StreamingContext and records that the creating function was invoked
+    var newContextCreated = false
+    def creatingFunction(): StreamingContext = {
+      newContextCreated = true
+      new StreamingContext(conf, batchDuration)
+    }
+
+    // Resets the flag, runs body, then always stops and clears ssc
+    def testGetOrCreate(body: => Unit): Unit = {
+      newContextCreated = false
+      try {
+        body
+      } finally {
+        if (ssc != null) {
+          ssc.stop()
+        }
+        ssc = null
+      }
+    }
+
+    val emptyPath = Utils.createTempDir().getAbsolutePath()
+
+    // getOrCreate should create new context with empty path
+    testGetOrCreate {
+      ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _)
+      assert(ssc != null, "no context created")
+      assert(newContextCreated, "new context not created")
+    }
+
+    val corruptedCheckpointPath = createCorruptedCheckpoint()
+
+    // getOrCreate should throw exception with fake checkpoint file (default createOnError)
+    intercept[Exception] {
+      ssc = StreamingContext.getOrCreate(corruptedCheckpointPath, creatingFunction _)
+    }
+
+    // getOrCreate should throw exception with fake checkpoint file and createOnError = false
+    intercept[Exception] {
+      ssc = StreamingContext.getOrCreate(
+        corruptedCheckpointPath, creatingFunction _, createOnError = false)
+    }
+
+    // getOrCreate should create new context with fake checkpoint file and createOnError = true
+    testGetOrCreate {
+      ssc = StreamingContext.getOrCreate(
+        corruptedCheckpointPath, creatingFunction _, createOnError = true)
+      assert(ssc != null, "no context created")
+      assert(newContextCreated, "new context not created")
+    }
+
+    val checkpointPath = createValidCheckpoint()
+
+    // getOrCreate should recover context with checkpoint path, and recover old configuration
+    testGetOrCreate {
+      ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _)
+      assert(ssc != null, "no context created")
+      assert(!newContextCreated, "old context not recovered")
+      assert(ssc.conf.get("someKey") === "someValue", "old configuration not recovered")
+    }
+  }
+
+  test("getOrCreate with existing SparkContext") {
+    val conf = new SparkConf().setMaster(master).setAppName(appName)
+    sc = new SparkContext(conf)
+
+    // Creates a new StreamingContext on the given SparkContext and records the invocation
+    var newContextCreated = false
+    def creatingFunction(sparkContext: SparkContext): StreamingContext = {
+      newContextCreated = true
+      new StreamingContext(sparkContext, batchDuration)
+    }
+
+    // Resets the flag, runs body, then stops ssc while keeping the shared SparkContext alive
+    def testGetOrCreate(body: => Unit): Unit = {
+      newContextCreated = false
+      try {
+        body
+      } finally {
+        if (ssc != null) {
+          ssc.stop(stopSparkContext = false)
+        }
+        ssc = null
+      }
+    }
+
+    val emptyPath = Utils.createTempDir().getAbsolutePath()
+
+    // getOrCreate should create new context with empty path (without relying on createOnError)
+    testGetOrCreate {
+      ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc)
+      assert(ssc != null, "no context created")
+      assert(newContextCreated, "new context not created")
+      assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+    }
+
+    val corruptedCheckpointPath = createCorruptedCheckpoint()
+
+    // getOrCreate should throw exception with fake checkpoint file (default createOnError)
+    intercept[Exception] {
+      ssc = StreamingContext.getOrCreate(corruptedCheckpointPath, creatingFunction _, sc)
+    }
+
+    // getOrCreate should throw exception with fake checkpoint file and createOnError = false
+    intercept[Exception] {
+      ssc = StreamingContext.getOrCreate(
+        corruptedCheckpointPath, creatingFunction _, sc, createOnError = false)
+    }
+
+    // getOrCreate should create new context with fake checkpoint file and createOnError = true
+    testGetOrCreate {
+      ssc = StreamingContext.getOrCreate(
+        corruptedCheckpointPath, creatingFunction _, sc, createOnError = true)
+      assert(ssc != null, "no context created")
+      assert(newContextCreated, "new context not created")
+      assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+    }
+
+    val checkpointPath = createValidCheckpoint()
+
+    // getOrCreate should recover context with checkpoint path, using the existing SparkContext
+    testGetOrCreate {
+      ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc)
+      assert(ssc != null, "no context created")
+      assert(!newContextCreated, "old context not recovered")
+      assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext")
+      assert(!ssc.conf.contains("someKey"),
+        "recovered StreamingContext unexpectedly has old config")
+    }
+  }
+
   test("DStream and generated RDD creation sites") {
     testPackage.test()
   }
@@ -339,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     val inputStream = new TestInputStream(s, input, 1)
     inputStream
   }
+
+  /** Creates checkpoints from a context whose conf has someKey=someValue; returns the dir. */
+  def createValidCheckpoint(): String = {
+    val checkpointDirectory = Utils.createTempDir().getAbsolutePath()
+    val conf = new SparkConf().setMaster(master).setAppName(appName).set("someKey", "someValue")
+    ssc = new StreamingContext(conf, batchDuration)
+    ssc.checkpoint(checkpointDirectory)
+    ssc.textFileStream(Utils.createTempDir().getAbsolutePath()).foreachRDD { rdd => rdd.count() }
+    ssc.start()
+    try {
+      eventually(timeout(10000 millis)) {
+        assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1)
+      }
+    } finally ssc.stop()  // stop even if no checkpoint appears in time, so the context can't leak
+    checkpointDirectory
+  }
+
+  /** Creates a fresh dir containing a garbage-filled checkpoint file; returns the dir path. */
+  def createCorruptedCheckpoint(): String = {
+    val dir = Utils.createTempDir().getAbsolutePath()
+    FileUtils.write(new File(Checkpoint.checkpointFile(dir, Time(1000)).toString()), "blablabla")
+    assert(Checkpoint.getCheckpointFiles(dir).nonEmpty)
+    dir
+  }
 }
 
 class TestException(msg: String) extends Exception(msg)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org