Posted to commits@spark.apache.org by pw...@apache.org on 2014/05/07 18:48:37 UTC

git commit: SPARK-1579: Clean up PythonRDD and avoid swallowing IOExceptions

Repository: spark
Updated Branches:
  refs/heads/master 967635a24 -> 3308722ca


SPARK-1579: Clean up PythonRDD and avoid swallowing IOExceptions

This patch includes several cleanups to PythonRDD, focused on fixing [SPARK-1579](https://issues.apache.org/jira/browse/SPARK-1579) cleanly. Listed in order of approximate importance:

- The Python daemon waits for Spark to close the socket before exiting,
  in order to avoid causing spurious IOExceptions in Spark's
  `PythonRDD::WriterThread`.
- Removes the Python Monitor Thread, which polled for task cancellations
  in order to kill the Python worker. Instead, we do this in the
  onCompleteCallback, since this is guaranteed to be called during
  cancellation.
- Adds a "completed" variable to TaskContext to avoid the issue noted in
  [SPARK-1019](https://issues.apache.org/jira/browse/SPARK-1019), where onCompleteCallbacks may be execution-order dependent.
  Along with this, I removed the "context.interrupted = true" flag in
  the onCompleteCallback.
- Extracts PythonRDD::WriterThread to its own class.
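
As a rough illustration of the "completed" flag, here is a minimal standalone sketch (hypothetical `MiniTaskContext`, not Spark's actual class) that mirrors the pattern added to TaskContext in the diff below: the flag is set before the callbacks run, so any callback can rely on it regardless of registration order.

```
import scala.collection.mutable.ArrayBuffer

// Standalone mock of the pattern this patch adds to TaskContext:
// a volatile `completed` flag set before the callbacks execute,
// with callbacks run in reverse order of registration.
class MiniTaskContext {
  @volatile var completed: Boolean = false
  private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  def addOnCompleteCallback(f: () => Unit): Unit = {
    onCompleteCallbacks += f
  }

  def executeOnCompleteCallbacks(): Unit = {
    completed = true
    // Process complete callbacks in the reverse order of registration
    onCompleteCallbacks.reverse.foreach { _() }
  }
}

object MiniTaskContextDemo extends App {
  val context = new MiniTaskContext
  context.addOnCompleteCallback { () =>
    // Callbacks can now observe `completed == true` no matter when they
    // were registered (cf. SPARK-1019), e.g. to clean up a worker socket.
    assert(context.completed)
    println("cleaning up worker socket")
  }
  context.executeOnCompleteCallbacks()
}
```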

Since this patch provides an alternative solution to [SPARK-1019](https://issues.apache.org/jira/browse/SPARK-1019), I tested it by running

```
sc.textFile("latlon.tsv").take(5)
```

many times without error.

Additionally, to verify that exceptions are no longer swallowed, I ran

```
sc.textFile("s3n://<big file>").count()
```

and cut my internet connection during execution. Prior to this patch, this produced only the unhelpful "stdin writer exited early" message. Now, the SocketExceptions are propagated through Spark to the user, and the tasks are properly (though unsuccessfully) retried.
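
The propagation works as shown in the PythonRDD diff below: the writer thread records any exception in a volatile field and the reading side rethrows it. A minimal standalone sketch of that pattern (hypothetical `SketchWriterThread`/`SketchWriterDemo` names, not Spark's actual classes):

```
import java.io.IOException

class SketchWriterThread extends Thread("stdout writer (sketch)") {
  @volatile private var _exception: Exception = null

  // Holds any exception thrown while writing, for the reader to rethrow.
  def exception: Option[Exception] = Option(_exception)

  override def run(): Unit = {
    try {
      // Simulate the stream write failing mid-task, e.g. a dropped connection.
      throw new IOException("connection lost while writing to the Python worker")
    } catch {
      // Never rethrow here: an uncaught exception in this thread would kill
      // the whole executor, so stash it for the reading side instead.
      case e: Exception => _exception = e
    }
  }
}

object SketchWriterDemo extends App {
  val writer = new SketchWriterThread
  writer.start()
  writer.join()
  // The reading side checks the stored exception and surfaces it, so the
  // failure reaches the user (and Spark's retry logic) instead of being
  // reduced to an unhelpful log line.
  writer.exception.foreach(e => throw e)
}
```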

Author: Aaron Davidson <aa...@databricks.com>

Closes #640 from aarondav/pyspark-io and squashes the following commits:

b391ff8 [Aaron Davidson] Detect "clean socket shutdowns" and stop waiting on the socket
c0c49da [Aaron Davidson] SPARK-1579: Clean up PythonRDD and avoid swallowing IOExceptions


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3308722c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3308722c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3308722c

Branch: refs/heads/master
Commit: 3308722ca03f2bfa792e9a2cff9c894b967983d9
Parents: 967635a
Author: Aaron Davidson <aa...@databricks.com>
Authored: Wed May 7 09:48:31 2014 -0700
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Wed May 7 09:48:31 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskContext.scala    |   5 +
 .../org/apache/spark/api/python/PythonRDD.scala | 217 ++++++++++---------
 .../apache/spark/scheduler/ShuffleMapTask.scala |  10 +-
 python/pyspark/context.py                       |   2 +-
 python/pyspark/daemon.py                        |  14 +-
 5 files changed, 141 insertions(+), 107 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3308722c/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index dc012cc..fc48127 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -42,9 +42,13 @@ class TaskContext(
   // List of callback functions to execute when the task completes.
   @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
+  // Set to true when the task is completed, before the onCompleteCallbacks are executed.
+  @volatile var completed: Boolean = false
+
   /**
    * Add a callback function to be executed on task completion. An example use
    * is for HadoopRDD to register a callback to close the input stream.
+   * Will be called in any situation - success, failure, or cancellation.
    * @param f Callback function.
    */
   def addOnCompleteCallback(f: () => Unit) {
@@ -52,6 +56,7 @@ class TaskContext(
   }
 
   def executeOnCompleteCallbacks() {
+    completed = true
     // Process complete callbacks in the reverse order of registration
     onCompleteCallbacks.reverse.foreach{_()}
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3308722c/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6140700..fecd976 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
     val env = SparkEnv.get
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
-    // Ensure worker socket is closed on task completion. Closing sockets is idempotent.
-    context.addOnCompleteCallback(() =>
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = new WriterThread(env, worker, split, context)
+
+    context.addOnCompleteCallback { () =>
+      writerThread.shutdownOnTaskCompletion()
+
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
       try {
         worker.close()
       } catch {
         case e: Exception => logWarning("Failed to close worker socket", e)
       }
-    )
-
-    @volatile var readerException: Exception = null
-
-    // Start a thread to feed the process input from our parent's iterator
-    new Thread("stdin writer for " + pythonExec) {
-      override def run() {
-        try {
-          SparkEnv.set(env)
-          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-          val dataOut = new DataOutputStream(stream)
-          // Partition index
-          dataOut.writeInt(split.index)
-          // sparkFilesDir
-          PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
-          // Broadcast variables
-          dataOut.writeInt(broadcastVars.length)
-          for (broadcast <- broadcastVars) {
-            dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
-          }
-          // Python includes (*.zip and *.egg files)
-          dataOut.writeInt(pythonIncludes.length)
-          for (include <- pythonIncludes) {
-            PythonRDD.writeUTF(include, dataOut)
-          }
-          dataOut.flush()
-          // Serialized command:
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-          // Data values
-          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
-          dataOut.flush()
-          worker.shutdownOutput()
-        } catch {
-
-          case e: java.io.FileNotFoundException =>
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-
-          case e: IOException =>
-            // This can happen for legitimate reasons if the Python code stops returning data
-            // before we are done passing elements through, e.g., for take(). Just log a message to
-            // say it happened (as it could also be hiding a real IOException from a data source).
-            logInfo("stdin writer to Python finished early (may not be an error)", e)
-
-          case e: Exception =>
-            // We must avoid throwing exceptions here, because the thread uncaught exception handler
-            // will kill the whole executor (see Executor).
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-        }
-      }
-    }.start()
-
-    // Necessary to distinguish between a task that has failed and a task that is finished
-    @volatile var complete: Boolean = false
-
-    // It is necessary to have a monitor thread for python workers if the user cancels with
-    // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-    // threads can block indefinitely.
-    new Thread(s"Worker Monitor for $pythonExec") {
-      override def run() {
-        // Kill the worker if it is interrupted or completed
-        // When a python task completes, the context is always set to interupted
-        while (!context.interrupted) {
-          Thread.sleep(2000)
-        }
-        if (!complete) {
-          try {
-            logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.toMap)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }.start()
-
-    /*
-     * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
-     * other completion callbacks might invalidate the input. Because interruption
-     * is not synchronous this still leaves a potential race where the interruption is
-     * processed only after the stream becomes invalid.
-     */
-    context.addOnCompleteCallback{ () =>
-      complete = true // Indicate that the task has completed successfully
-      context.interrupted = true
     }
 
+    writerThread.start()
+    new MonitorThread(env, worker, context).start()
+
     // Return an iterator that read lines from the process's stdout
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     val stdoutIterator = new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
         if (hasNext) {
-          // FIXME: can deadlock if worker is waiting for us to
-          // respond to current message (currently irrelevant because
-          // output is shutdown before we read any input)
           _nextObj = read()
         }
         obj
       }
 
       private def read(): Array[Byte] = {
-        if (readerException != null) {
-          throw readerException
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
         }
         try {
           stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                 init, finish))
-              read
+              read()
             case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
-              throw new PythonException(new String(obj, "utf-8"), readerException)
+              throw new PythonException(new String(obj, "utf-8"),
+                writerThread.exception.getOrElse(null))
             case SpecialLengths.END_OF_DATA_SECTION =>
               // We've finished the data section of the output, but we can still
               // read some accumulator updates:
@@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag](
               Array.empty[Byte]
           }
         } catch {
-          case e: Exception if readerException != null =>
+
+          case e: Exception if context.interrupted =>
+            logDebug("Exception thrown after task interruption", e)
+            throw new TaskKilledException
+
+          case e: Exception if writerThread.exception.isDefined =>
             logError("Python worker exited unexpectedly (crashed)", e)
-            logError("Python crash may have been caused by prior exception:", readerException)
-            throw readerException
+            logError("This may have been caused by a prior exception:", writerThread.exception.get)
+            throw writerThread.exception.get
 
           case eof: EOFException =>
             throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag](
 
       def hasNext = _nextObj.length != 0
     }
-    stdoutIterator
+    new InterruptibleIterator(context, stdoutIterator)
   }
 
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+   * Python process.
+   */
+  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Exception = null
+
+    setDaemon(true)
+
+    /** Contains the exception thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Exception] = Option(_exception)
+
+    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+    def shutdownOnTaskCompletion() {
+      assert(context.completed)
+      this.interrupt()
+    }
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(split.index)
+        // sparkFilesDir
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+        // Broadcast variables
+        dataOut.writeInt(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
+        }
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.length)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        dataOut.flush()
+        // Serialized command:
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+        // Data values
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.completed || context.interrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+      } finally {
+        Try(worker.shutdownOutput()) // kill Python worker process
+      }
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for python workers if the user cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
+   * threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if it is interrupted, checking until task completion.
+      // TODO: This has a race condition if interruption occurs, as completed may still become true.
+      while (!context.interrupted && !context.completed) {
+        Thread.sleep(2000)
+      }
+      if (!context.completed) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+          env.destroyPythonWorker(pythonExec, envVars.toMap)
+        } catch {
+          case e: Exception =>
+            logError("Exception when trying to kill worker", e)
+        }
+      }
+    }
+  }
 }
 
 /** Thrown for exceptions in user Python code. */

http://git-wip-us.apache.org/repos/asf/spark/blob/3308722c/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 02b62de..2259df0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.scheduler
 
+import scala.language.existentials
+
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.mutable.HashMap
-import scala.language.existentials
+import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
@@ -196,7 +198,11 @@ private[spark] class ShuffleMapTask(
     } finally {
       // Release the writers back to the shuffle block manager.
       if (shuffle != null && shuffle.writers != null) {
-        shuffle.releaseWriters(success)
+        try {
+          shuffle.releaseWriters(success)
+        } catch {
+          case e: Exception => logError("Failed to release shuffle writers", e)
+        }
       }
       // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()

http://git-wip-us.apache.org/repos/asf/spark/blob/3308722c/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c7dc85e..cac133d 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -453,7 +453,7 @@ class SparkContext(object):
         >>> lock = threading.Lock()
         >>> def map_func(x):
         ...     sleep(100)
-        ...     return x * x
+        ...     raise Exception("Task should have been cancelled")
         >>> def start_job(x):
         ...     global result
         ...     try:

http://git-wip-us.apache.org/repos/asf/spark/blob/3308722c/python/pyspark/daemon.py
----------------------------------------------------------------------
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index eb18ec0..b2f226a 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -74,6 +74,17 @@ def worker(listen_sock):
                 raise
     signal.signal(SIGCHLD, handle_sigchld)
 
+    # Blocks until the socket is closed by draining the input stream
+    # until it raises an exception or returns EOF.
+    def waitSocketClose(sock):
+        try:
+            while True:
+                # Empty string is returned upon EOF (and only then).
+                if sock.recv(4096) == '':
+                    return
+        except:
+            pass
+
     # Handle clients
     while not should_exit():
         # Wait until a client arrives or we have to exit
@@ -105,7 +116,8 @@ def worker(listen_sock):
                     exit_code = exc.code
                 finally:
                     outfile.flush()
-                    sock.close()
+                    # The Scala side will close the socket upon task completion.
+                    waitSocketClose(sock)
                     os._exit(compute_real_exit_code(exit_code))
             else:
                 sock.close()