You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/07/09 02:32:14 UTC

[spark] branch branch-3.2 updated: [SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 55111ca  [SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes
55111ca is described below

commit 55111cafd122d6c218be675260a54e1863d57394
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Fri Jul 9 11:30:39 2021 +0900

    [SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes
    
    ### What changes were proposed in this pull request?
    
    Try to capture the error message from the `faulthandler` when the Python worker crashes.
    
    ### Why are the changes needed?
    
    Currently, we only see an error message saying `"exited unexpectedly (crashed)"` when a UDF causes the Python worker to crash, e.g., with a segmentation fault.
    We should take advantage of [`faulthandler`](https://docs.python.org/3/library/faulthandler.html) and try to capture the error message from the `faulthandler`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, when the Spark config `spark.python.worker.faulthandler.enabled` is `true`, the stack trace is shown in the error message when the Python worker crashes.
    
    ```py
    >>> def f():
    ...   import ctypes
    ...   ctypes.string_at(0)
    ...
    >>> sc.parallelize([1]).map(lambda x: f()).count()
    ```
    
    ```
    org.apache.spark.SparkException: Python worker exited unexpectedly (crashed): Fatal Python error: Segmentation fault
    
    Current thread 0x000000010965b5c0 (most recent call first):
      File "/.../ctypes/__init__.py", line 525 in string_at
      File "<stdin>", line 3 in f
      File "<stdin>", line 1 in <lambda>
    ...
    ```
    
    ### How was this patch tested?
    
    Added some tests, and tested manually.
    
    Closes #33273 from ueshin/issues/SPARK-36062/faulthandler.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 115b8a180f41fe957341b0725c3f34499267bb92)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../src/main/scala/org/apache/spark/SparkEnv.scala |  4 ++-
 .../org/apache/spark/api/python/PythonRunner.scala | 33 ++++++++++++++++++++--
 .../spark/api/python/PythonWorkerFactory.scala     | 20 ++++++++-----
 .../org/apache/spark/internal/config/Python.scala  |  8 ++++++
 python/pyspark/tests/test_worker.py                | 29 +++++++++++++++++++
 python/pyspark/worker.py                           | 15 ++++++++++
 .../sql/execution/python/PythonArrowOutput.scala   |  4 ++-
 .../sql/execution/python/PythonUDFRunner.scala     |  4 ++-
 8 files changed, 104 insertions(+), 13 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ed8dc43..ee50a8f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -113,7 +113,9 @@ class SparkEnv (
   }
 
   private[spark]
-  def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+  def createPythonWorker(
+      pythonExec: String,
+      envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
     synchronized {
       val key = (pythonExec, envVars)
       pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 6e2b6ad..db0e100 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -21,6 +21,7 @@ import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
 import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.{Files => JavaFiles, Path}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 
@@ -65,6 +66,15 @@ private[spark] object PythonEvalType {
   }
 }
 
+private object BasePythonRunner {
+
+  private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
+
+  private def faultHandlerLogPath(pid: Int): Path = {
+    new File(faultHandlerLogDir, pid.toString).toPath
+  }
+}
+
 /**
  * A helper class to run Python mapPartition/UDFs in Spark.
  *
@@ -83,6 +93,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
+  private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
   protected val simplifiedTraceback: Boolean = false
 
   // All the Python functions should have the same exec, version and envvars.
@@ -143,7 +154,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
     envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-    val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
+    if (faultHandlerEnabled) {
+      envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
+    }
+
+    val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
+      pythonExec, envVars.asScala.toMap)
     // Whether is the worker released into idle pool or closed. When any codes try to release or
     // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
     // sure there is only one winner that is going to release or close the worker.
@@ -180,7 +196,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
 
     val stdoutIterator = newReaderIterator(
-      stream, writerThread, startTime, env, worker, releasedOrClosed, context)
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
     new InterruptibleIterator(context, stdoutIterator)
   }
 
@@ -197,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]
 
@@ -468,6 +485,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext)
     extends Iterator[OUT] {
@@ -556,6 +574,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         logError("This may have been caused by a prior exception:", writerThread.exception.get)
         throw writerThread.exception.get
 
+      case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
+          JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
+        val path = BasePythonRunner.faultHandlerLogPath(pid.get)
+        val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
+        JavaFiles.deleteIfExists(path)
+        throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
+
       case eof: EOFException =>
         throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
     }
@@ -654,9 +679,11 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index df236ba..7b2c36b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -95,11 +95,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))
 
-  def create(): Socket = {
+  def create(): (Socket, Option[Int]) = {
     if (useDaemon) {
       self.synchronized {
         if (idleWorkers.nonEmpty) {
-          return idleWorkers.dequeue()
+          val worker = idleWorkers.dequeue()
+          return (worker, daemonWorkers.get(worker))
         }
       }
       createThroughDaemon()
@@ -113,9 +114,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
    * processes itself to avoid the high cost of forking from Java. This currently only works
    * on UNIX-based systems.
    */
-  private def createThroughDaemon(): Socket = {
+  private def createThroughDaemon(): (Socket, Option[Int]) = {
 
-    def createSocket(): Socket = {
+    def createSocket(): (Socket, Option[Int]) = {
       val socket = new Socket(daemonHost, daemonPort)
       val pid = new DataInputStream(socket.getInputStream).readInt()
       if (pid < 0) {
@@ -124,7 +125,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
       authHelper.authToServer(socket)
       daemonWorkers.put(socket, pid)
-      socket
+      (socket, Some(pid))
     }
 
     self.synchronized {
@@ -148,7 +149,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   /**
    * Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
    */
-  private def createSimpleWorker(): Socket = {
+  private def createSimpleWorker(): (Socket, Option[Int]) = {
     var serverSocket: ServerSocket = null
     try {
       serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
@@ -173,10 +174,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       try {
         val socket = serverSocket.accept()
         authHelper.authClient(socket)
+        // TODO: When we drop JDK 8, we can just use worker.pid()
+        val pid = new DataInputStream(socket.getInputStream).readInt()
+        if (pid < 0) {
+          throw new IllegalStateException("Python failed to launch worker with code " + pid)
+        }
         self.synchronized {
           simpleWorkers.put(socket, worker)
         }
-        return socket
+        return (socket, Some(pid))
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker failed to connect back.", e)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 348a33e..5e026fd 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -56,4 +56,12 @@ private[spark] object Python {
     .version("3.1.0")
     .timeConf(TimeUnit.SECONDS)
     .createWithDefaultString("15s")
+
+  val PYTHON_WORKER_FAULTHANLDER_ENABLED = ConfigBuilder("spark.python.worker.faulthandler.enabled")
+    .doc("When true, Python workers set up the faulthandler for the case when the Python worker " +
+      "exits unexpectedly (crashes), and shows the stack trace of the moment the Python worker " +
+      "crashes in the error message if captured successfully.")
+    .version("3.2.0")
+    .booleanConf
+    .createWithDefault(false)
 }
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 120c5e3..a77d38e 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -206,6 +206,35 @@ class WorkerMemoryTest(unittest.TestCase):
     def tearDown(self):
         self.sc.stop()
 
+
+class WorkerSegfaultTest(ReusedPySparkTestCase):
+
+    @classmethod
+    def conf(cls):
+        _conf = super(WorkerSegfaultTest, cls).conf()
+        _conf.set("spark.python.worker.faulthandler.enabled", "true")
+        return _conf
+
+    def test_python_segfault(self):
+        try:
+            def f():
+                import ctypes
+                ctypes.string_at(0)
+
+            self.sc.parallelize([1]).map(lambda x: f()).count()
+        except Py4JJavaError as e:
+            self.assertRegex(str(e), "Segmentation fault")
+
+
+class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):
+
+    @classmethod
+    def conf(cls):
+        _conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
+        _conf.set("spark.python.use.daemon", "false")
+        return _conf
+
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_worker import *  # noqa: F401
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 023a655..a13717f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,6 +31,7 @@ except ImportError:
     has_resource_module = False
 import traceback
 import warnings
+import faulthandler
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -463,7 +464,13 @@ def read_udfs(pickleSer, infile, eval_type):
 
 
 def main(infile, outfile):
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         boot_time = time.time()
         split_index = read_int(infile)
         if split_index == -1:  # for unit tests
@@ -636,6 +643,11 @@ def main(infile, outfile):
             print("PySpark worker failed with exception:", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     write_long(shuffle.MemoryBytesSpilled, outfile)
@@ -661,4 +673,7 @@ if __name__ == '__main__':
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
     (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    # TODO: Remove thw following two lines and use `Process.pid()` when we drop JDK 8.
+    write_int(os.getpid(), sock_file)
+    sock_file.flush()
     main(sock_file, sock_file)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index bb35306..00bab1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -43,10 +43,12 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[ColumnarBatch] = {
 
-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
 
       private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         s"stdin reader for $pythonExec", 0, Long.MaxValue)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index d9fe072..d1109d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -62,9 +62,11 @@ class PythonUDFRunner(
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
+      pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org