Posted to commits@spark.apache.org by sr...@apache.org on 2016/03/16 10:58:34 UTC

spark git commit: [SPARK-13793][CORE] PipedRDD doesn't propagate exceptions while reading parent RDD

Repository: spark
Updated Branches:
  refs/heads/master 56d88247f -> 1d95fb678


[SPARK-13793][CORE] PipedRDD doesn't propagate exceptions while reading parent RDD

## What changes were proposed in this pull request?

PipedRDD creates a child thread to read the output of the parent stage and feed it to the pipe process. This change saves any exception thrown in that child thread in a shared variable and then rethrows it in the main thread if the variable was set, so the failure propagates instead of being silently lost.
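
A minimal, self-contained sketch of that pattern outside of Spark (the `childThreadException` name mirrors the patch; the failing worker body is purely illustrative):

```scala
import java.util.concurrent.atomic.AtomicReference

object ExceptionPropagationSketch {
  def main(args: Array[String]): Unit = {
    // Shared slot the child thread uses to hand a failure back to the main thread.
    val childThreadException = new AtomicReference[Throwable](null)

    val writer = new Thread("stdin writer sketch") {
      override def run(): Unit = {
        try {
          // Stand-in for "iterate over the parent RDD and feed the pipe process".
          throw new RuntimeException("simulated failure while reading parent RDD")
        } catch {
          // Record the failure instead of letting it die with the thread.
          case t: Throwable => childThreadException.set(t)
        }
      }
    }
    writer.start()
    writer.join()

    // Main thread: rethrow so the caller sees the child's failure.
    val t = childThreadException.get()
    if (t != null) throw t
  }
}
```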

## How was this patch tested?

- Added a unit test (a minimal user-facing reproduction is sketched after this list)
- Ran all the existing tests in PipedRDDSuite; they all pass with the change
- Tested the patch with a real pipe() job: bounced the executor node that ran the parent stage to simulate a fetch failure and observed that the parent stage was re-run
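
For reference, a minimal reproduction of the user-facing behavior, assuming a live `SparkContext` named `sc` and a `cat` binary on the executors (the exception message is illustrative):

```scala
import org.apache.spark.SparkException

// Parent stage whose iterator fails while the pipe's stdin writer consumes it.
val failing = sc.parallelize(1 to 4, 2).map[String] { _ =>
  throw new SparkException("simulated failure in parent stage")
}

// Before this patch the stdin-writer thread died silently; with it, the
// original exception surfaces through collect() and the stage can be retried.
try {
  failing.pipe(Seq("cat")).collect()
} catch {
  case e: SparkException => println(s"propagated: ${e.getMessage}")
}
```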

Author: Tejas Patil <te...@fb.com>

Closes #11628 from tejasapatil/pipe_rdd.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d95fb67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d95fb67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d95fb67

Branch: refs/heads/master
Commit: 1d95fb6785dd77d879d3d60e15320f72ab185fd3
Parents: 56d8824
Author: Tejas Patil <te...@fb.com>
Authored: Wed Mar 16 09:58:53 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Wed Mar 16 09:58:53 2016 +0000

----------------------------------------------------------------------
 .../scala/org/apache/spark/rdd/PipedRDD.scala   | 97 +++++++++++++-------
 .../org/apache/spark/rdd/PipedRDDSuite.scala    | 21 +++++
 2 files changed, 86 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d95fb67/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index afbe566..50b4184 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -22,12 +22,14 @@ import java.io.FilenameFilter
 import java.io.IOException
 import java.io.PrintWriter
 import java.util.StringTokenizer
+import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
+import scala.util.control.NonFatal
 
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.util.Utils
@@ -118,63 +120,94 @@ private[spark] class PipedRDD[T: ClassTag](
 
     val proc = pb.start()
     val env = SparkEnv.get
+    val childThreadException = new AtomicReference[Throwable](null)
 
     // Start a thread to print the process's stderr to ours
-    new Thread("stderr reader for " + command) {
-      override def run() {
-        for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
-          // scalastyle:off println
-          System.err.println(line)
-          // scalastyle:on println
+    new Thread(s"stderr reader for $command") {
+      override def run(): Unit = {
+        val err = proc.getErrorStream
+        try {
+          for (line <- Source.fromInputStream(err).getLines) {
+            // scalastyle:off println
+            System.err.println(line)
+            // scalastyle:on println
+          }
+        } catch {
+          case t: Throwable => childThreadException.set(t)
+        } finally {
+          err.close()
         }
       }
     }.start()
 
     // Start a thread to feed the process input from our parent's iterator
-    new Thread("stdin writer for " + command) {
-      override def run() {
+    new Thread(s"stdin writer for $command") {
+      override def run(): Unit = {
         TaskContext.setTaskContext(context)
         val out = new PrintWriter(proc.getOutputStream)
-
-        // scalastyle:off println
-        // input the pipe context firstly
-        if (printPipeContext != null) {
-          printPipeContext(out.println(_))
-        }
-        for (elem <- firstParent[T].iterator(split, context)) {
-          if (printRDDElement != null) {
-            printRDDElement(elem, out.println(_))
-          } else {
-            out.println(elem)
+        try {
+          // scalastyle:off println
+          // input the pipe context firstly
+          if (printPipeContext != null) {
+            printPipeContext(out.println)
+          }
+          for (elem <- firstParent[T].iterator(split, context)) {
+            if (printRDDElement != null) {
+              printRDDElement(elem, out.println)
+            } else {
+              out.println(elem)
+            }
           }
+          // scalastyle:on println
+        } catch {
+          case t: Throwable => childThreadException.set(t)
+        } finally {
+          out.close()
         }
-        // scalastyle:on println
-        out.close()
       }
     }.start()
 
     // Return an iterator that reads lines from the process's stdout
     val lines = Source.fromInputStream(proc.getInputStream).getLines()
     new Iterator[String] {
-      def next(): String = lines.next()
-      def hasNext: Boolean = {
-        if (lines.hasNext) {
+      def next(): String = {
+        if (!hasNext()) {
+          throw new NoSuchElementException()
+        }
+        lines.next()
+      }
+
+      def hasNext(): Boolean = {
+        val result = if (lines.hasNext) {
           true
         } else {
           val exitStatus = proc.waitFor()
+          cleanup()
           if (exitStatus != 0) {
-            throw new Exception("Subprocess exited with status " + exitStatus)
+            throw new IllegalStateException(s"Subprocess exited with status $exitStatus")
           }
+          false
+        }
+        propagateChildException()
+        result
+      }
 
-          // cleanup task working directory if used
-          if (workInTaskDirectory) {
-            scala.util.control.Exception.ignoring(classOf[IOException]) {
-              Utils.deleteRecursively(new File(taskDirectory))
-            }
-            logDebug("Removed task working directory " + taskDirectory)
+      private def cleanup(): Unit = {
+        // cleanup task working directory if used
+        if (workInTaskDirectory) {
+          scala.util.control.Exception.ignoring(classOf[IOException]) {
+            Utils.deleteRecursively(new File(taskDirectory))
           }
+          logDebug(s"Removed task working directory $taskDirectory")
+        }
+      }
 
-          false
+      private def propagateChildException(): Unit = {
+        val t = childThreadException.get()
+        if (t != null) {
+          proc.destroy()
+          cleanup()
+          throw t
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d95fb67/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 1eebc92..d13da38 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -50,6 +50,27 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("failure in iterating over pipe input") {
+    if (testCommandAvailable("cat")) {
+      val nums =
+        sc.makeRDD(Array(1, 2, 3, 4), 2)
+          .mapPartitionsWithIndex((index, iterator) => {
+            new Iterator[Int] {
+              def hasNext = true
+              def next() = {
+                throw new SparkException("Exception to simulate bad scenario")
+              }
+            }
+          })
+
+      val piped = nums.pipe(Seq("cat"))
+
+      intercept[SparkException] {
+        piped.collect()
+      }
+    }
+  }
+
   test("advanced pipe") {
     if (testCommandAvailable("cat")) {
       val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)

