Posted to commits@spark.apache.org by gu...@apache.org on 2021/07/09 02:32:14 UTC
[spark] branch branch-3.2 updated: [SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 55111ca [SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes
55111ca is described below
commit 55111cafd122d6c218be675260a54e1863d57394
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Fri Jul 9 11:30:39 2021 +0900
[SPARK-36062][PYTHON] Try to capture faulthandler when a Python worker crashes
### What changes were proposed in this pull request?
Try to capture the error message from `faulthandler` when the Python worker crashes.
### Why are the changes needed?
Currently, we just see an error message saying `"exited unexpectedly (crashed)"` when a UDF causes the Python worker to crash, for example via a segmentation fault.
We should take advantage of [`faulthandler`](https://docs.python.org/3/library/faulthandler.html) and try to capture the error message it produces.
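For reference, a minimal sketch, independent of Spark, of what `faulthandler` does when a fatal signal arrives:
```py
import ctypes
import faulthandler

# Once enabled, faulthandler dumps the Python traceback of all threads when a
# fatal signal (e.g. SIGSEGV) is delivered. It writes to sys.stderr by default
# and also accepts a file= argument, which this change relies on.
faulthandler.enable()

# Dereference a NULL pointer: prints "Fatal Python error: Segmentation fault"
# plus the stack trace, then the process dies.
ctypes.string_at(0)
```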
### Does this PR introduce _any_ user-facing change?
Yes. When the Spark config `spark.python.worker.faulthandler.enabled` is `true`, the error message will include the stack trace captured at the moment the Python worker crashes.
```py
>>> def f():
... import ctypes
... ctypes.string_at(0)
...
>>> sc.parallelize([1]).map(lambda x: f()).count()
```
```
org.apache.spark.SparkException: Python worker exited unexpectedly (crashed): Fatal Python error: Segmentation fault
Current thread 0x000000010965b5c0 (most recent call first):
File "/.../ctypes/__init__.py", line 525 in string_at
File "<stdin>", line 3 in f
File "<stdin>", line 1 in <lambda>
...
```
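The feature is off by default. A minimal sketch of opting in when building the context (the config key comes from this change; the rest is standard PySpark setup):
```py
from pyspark import SparkConf, SparkContext

# Python workers launched by this context will enable faulthandler on startup.
conf = SparkConf().set("spark.python.worker.faulthandler.enabled", "true")
sc = SparkContext(conf=conf)
```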
### How was this patch tested?
Added some tests, and verified manually.
Closes #33273 from ueshin/issues/SPARK-36062/faulthandler.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 115b8a180f41fe957341b0725c3f34499267bb92)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++-
.../org/apache/spark/api/python/PythonRunner.scala | 33 ++++++++++++++++++++--
.../spark/api/python/PythonWorkerFactory.scala | 20 ++++++++-----
.../org/apache/spark/internal/config/Python.scala | 8 ++++++
python/pyspark/tests/test_worker.py | 29 +++++++++++++++++++
python/pyspark/worker.py | 15 ++++++++++
.../sql/execution/python/PythonArrowOutput.scala | 4 ++-
.../sql/execution/python/PythonUDFRunner.scala | 4 ++-
8 files changed, 104 insertions(+), 13 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ed8dc43..ee50a8f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -113,7 +113,9 @@ class SparkEnv (
}
private[spark]
- def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ def createPythonWorker(
+ pythonExec: String,
+ envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 6e2b6ad..db0e100 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -21,6 +21,7 @@ import java.io._
import java.net._
import java.nio.charset.StandardCharsets
import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.{Files => JavaFiles, Path}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
@@ -65,6 +66,15 @@ private[spark] object PythonEvalType {
}
}
+private object BasePythonRunner {
+
+ private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
+
+ private def faultHandlerLogPath(pid: Int): Path = {
+ new File(faultHandlerLogDir, pid.toString).toPath
+ }
+}
+
/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
@@ -83,6 +93,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val bufferSize: Int = conf.get(BUFFER_SIZE)
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
+ private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val simplifiedTraceback: Boolean = false
// All the Python functions should have the same exec, version and envvars.
@@ -143,7 +154,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
- val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
+ if (faultHandlerEnabled) {
+ envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
+ }
+
+ val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
+ pythonExec, envVars.asScala.toMap)
// Whether the worker is released into the idle pool or closed. When any code tries to release or
// close a worker, it should use `releasedOrClosed.compareAndSet` to flip the state to make
// sure there is only one winner that is going to release or close the worker.
@@ -180,7 +196,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = newReaderIterator(
- stream, writerThread, startTime, env, worker, releasedOrClosed, context)
+ stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context)
new InterruptibleIterator(context, stdoutIterator)
}
@@ -197,6 +213,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: Socket,
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT]
@@ -468,6 +485,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: Socket,
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext)
extends Iterator[OUT] {
@@ -556,6 +574,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
logError("This may have been caused by a prior exception:", writerThread.exception.get)
throw writerThread.exception.get
+ case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
+ JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) =>
+ val path = BasePythonRunner.faultHandlerLogPath(pid.get)
+ val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
+ JavaFiles.deleteIfExists(path)
+ throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof)
+
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
@@ -654,9 +679,11 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
startTime: Long,
env: SparkEnv,
worker: Socket,
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
- new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+ new ReaderIterator(
+ stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
protected override def read(): Array[Byte] = {
if (writerThread.exception.isDefined) {
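The core reader-side change above is the new `EOFException` branch. Roughly re-expressed as a Python sketch (the function and parameter names are illustrative, not from the patch):
```py
import os

def crash_error_message(fault_handler_log_dir, pid):
    # If the dead worker left a faulthandler dump behind for its pid, attach
    # it to the exception message; otherwise fall back to the old message.
    path = os.path.join(fault_handler_log_dir, str(pid)) if pid is not None else None
    if path is not None and os.path.exists(path):
        with open(path) as f:
            error = f.read()
        os.remove(path)  # one-shot, like the JavaFiles.deleteIfExists above
        return "Python worker exited unexpectedly (crashed): " + error
    return "Python worker exited unexpectedly (crashed)"
```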
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index df236ba..7b2c36b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -95,11 +95,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))
- def create(): Socket = {
+ def create(): (Socket, Option[Int]) = {
if (useDaemon) {
self.synchronized {
if (idleWorkers.nonEmpty) {
- return idleWorkers.dequeue()
+ val worker = idleWorkers.dequeue()
+ return (worker, daemonWorkers.get(worker))
}
}
createThroughDaemon()
@@ -113,9 +114,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
* processes itself to avoid the high cost of forking from Java. This currently only works
* on UNIX-based systems.
*/
- private def createThroughDaemon(): Socket = {
+ private def createThroughDaemon(): (Socket, Option[Int]) = {
- def createSocket(): Socket = {
+ def createSocket(): (Socket, Option[Int]) = {
val socket = new Socket(daemonHost, daemonPort)
val pid = new DataInputStream(socket.getInputStream).readInt()
if (pid < 0) {
@@ -124,7 +125,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
authHelper.authToServer(socket)
daemonWorkers.put(socket, pid)
- socket
+ (socket, Some(pid))
}
self.synchronized {
@@ -148,7 +149,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
/**
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
*/
- private def createSimpleWorker(): Socket = {
+ private def createSimpleWorker(): (Socket, Option[Int]) = {
var serverSocket: ServerSocket = null
try {
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
@@ -173,10 +174,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
val socket = serverSocket.accept()
authHelper.authClient(socket)
+ // TODO: When we drop JDK 8, we can just use worker.pid()
+ val pid = new DataInputStream(socket.getInputStream).readInt()
+ if (pid < 0) {
+ throw new IllegalStateException("Python failed to launch worker with code " + pid)
+ }
self.synchronized {
simpleWorkers.put(socket, worker)
}
- return socket
+ return (socket, Some(pid))
} catch {
case e: Exception =>
throw new SparkException("Python worker failed to connect back.", e)
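The non-daemon path above adds a small pid handshake: after connecting back and authenticating, the worker sends its pid as a big-endian int32, which the JVM reads with `DataInputStream.readInt()` (a negative value signals a launch failure). A distilled sketch of the worker side, with authentication omitted (not the actual pyspark source):
```py
import os
import socket
import struct

def connect_and_send_pid(java_port):
    # Connect back to the JVM and send our pid first, matching the
    # DataInputStream.readInt() on the Scala side (big-endian int32).
    sock = socket.create_connection(("127.0.0.1", java_port))
    sock_file = sock.makefile("rwb", 65536)
    sock_file.write(struct.pack("!i", os.getpid()))  # what pyspark's write_int does
    sock_file.flush()
    return sock_file
```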
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 348a33e..5e026fd 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -56,4 +56,12 @@ private[spark] object Python {
.version("3.1.0")
.timeConf(TimeUnit.SECONDS)
.createWithDefaultString("15s")
+
+ val PYTHON_WORKER_FAULTHANLDER_ENABLED = ConfigBuilder("spark.python.worker.faulthandler.enabled")
+ .doc("When true, Python workers set up the faulthandler for the case when the Python worker " +
+ "exits unexpectedly (crashes), and shows the stack trace of the moment the Python worker " +
+ "crashes in the error message if captured successfully.")
+ .version("3.2.0")
+ .booleanConf
+ .createWithDefault(false)
}
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 120c5e3..a77d38e 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -206,6 +206,35 @@ class WorkerMemoryTest(unittest.TestCase):
def tearDown(self):
self.sc.stop()
+
+class WorkerSegfaultTest(ReusedPySparkTestCase):
+
+ @classmethod
+ def conf(cls):
+ _conf = super(WorkerSegfaultTest, cls).conf()
+ _conf.set("spark.python.worker.faulthandler.enabled", "true")
+ return _conf
+
+ def test_python_segfault(self):
+ try:
+ def f():
+ import ctypes
+ ctypes.string_at(0)
+
+ self.sc.parallelize([1]).map(lambda x: f()).count()
+ except Py4JJavaError as e:
+ self.assertRegex(str(e), "Segmentation fault")
+
+
+class WorkerSegfaultNonDaemonTest(WorkerSegfaultTest):
+
+ @classmethod
+ def conf(cls):
+ _conf = super(WorkerSegfaultNonDaemonTest, cls).conf()
+ _conf.set("spark.python.use.daemon", "false")
+ return _conf
+
+
if __name__ == "__main__":
import unittest
from pyspark.tests.test_worker import * # noqa: F401
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 023a655..a13717f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,6 +31,7 @@ except ImportError:
has_resource_module = False
import traceback
import warnings
+import faulthandler
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -463,7 +464,13 @@ def read_udfs(pickleSer, infile, eval_type):
def main(infile, outfile):
+ faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
+ if faulthandler_log_path:
+ faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+ faulthandler_log_file = open(faulthandler_log_path, "w")
+ faulthandler.enable(file=faulthandler_log_file)
+
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
@@ -636,6 +643,11 @@ def main(infile, outfile):
print("PySpark worker failed with exception:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
sys.exit(-1)
+ finally:
+ if faulthandler_log_path:
+ faulthandler.disable()
+ faulthandler_log_file.close()
+ os.remove(faulthandler_log_path)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
@@ -661,4 +673,7 @@ if __name__ == '__main__':
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
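Distilled from the `worker.py` changes above: the JVM sets `PYTHON_FAULTHANDLER_DIR`, the worker dumps into a per-pid file under it, and the file is removed on a clean exit, so the file's mere existence is the crash signal the JVM checks for. A rough standalone sketch of that lifecycle:
```py
import faulthandler
import os

log_dir = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
log_path = log_file = None
try:
    if log_dir:
        # One dump file per worker process, named after its pid.
        log_path = os.path.join(log_dir, str(os.getpid()))
        log_file = open(log_path, "w")
        faulthandler.enable(file=log_file)

    ...  # process tasks; a hard crash here leaves the dump file behind

finally:
    if log_file:
        # Clean exit: disable the handler and remove the file so the JVM
        # does not mistake a normal shutdown for a crash.
        faulthandler.disable()
        log_file.close()
        os.remove(log_path)
```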
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index bb35306..00bab1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -43,10 +43,12 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
startTime: Long,
env: SparkEnv,
worker: Socket,
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[ColumnarBatch] = {
- new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+ new ReaderIterator(
+ stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
s"stdin reader for $pythonExec", 0, Long.MaxValue)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index d9fe072..d1109d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -62,9 +62,11 @@ class PythonUDFRunner(
startTime: Long,
env: SparkEnv,
worker: Socket,
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
- new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
+ new ReaderIterator(
+ stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
protected override def read(): Array[Byte] = {
if (writerThread.exception.isDefined) {