You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2017/03/03 00:46:41 UTC

spark git commit: [SPARK-19276][CORE] Fetch Failure handling robust to user error handling

Repository: spark
Updated Branches:
  refs/heads/master 433d9eb61 -> 8417a7ae6


[SPARK-19276][CORE] Fetch Failure handling robust to user error handling

## What changes were proposed in this pull request?

Fault tolerance in Spark requires special handling of shuffle fetch
failures.  The Executor catches FetchFailedException and sends a
special message back to the driver.

However, intervening user code could intercept that exception and wrap
it in something else; this even happens in Spark SQL.  So rather than
checking only the thrown exception, we now store the fetch failure directly
in the TaskContext, where user code can't touch it.

## How was this patch tested?

Added test cases, including one that failed before the fix.  Ran the full test suite via Jenkins.

Author: Imran Rashid <ir...@cloudera.com>

Closes #16639 from squito/SPARK-19276.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8417a7ae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8417a7ae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8417a7ae

Branch: refs/heads/master
Commit: 8417a7ae6c0ea3fb8dc41bc492fc9513d1ad24af
Parents: 433d9eb
Author: Imran Rashid <ir...@cloudera.com>
Authored: Thu Mar 2 16:46:01 2017 -0800
Committer: Kay Ousterhout <ka...@gmail.com>
Committed: Thu Mar 2 16:46:01 2017 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/TaskContext.scala    |   7 +
 .../org/apache/spark/TaskContextImpl.scala      |  11 ++
 .../org/apache/spark/executor/Executor.scala    |  33 ++++-
 .../scala/org/apache/spark/scheduler/Task.scala |   9 +-
 .../spark/shuffle/FetchFailedException.scala    |  13 +-
 .../apache/spark/executor/ExecutorSuite.scala   | 139 ++++++++++++++++++-
 project/MimaExcludes.scala                      |   3 +
 7 files changed, 198 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0fd777e..f0867ec 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}
 
 
@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
    */
   private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
 
+  /**
+   * Record that this task has failed due to a fetch failure from a remote host.  This allows
+   * fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
+   */
+  private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c904e08..dc0d128 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 private[spark] class TaskContextImpl(
@@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
   // Whether the task has failed.
   @volatile private var failed: Boolean = false
 
+  // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
+  // hide the exception.  See SPARK-19276
+  @volatile private var _fetchFailedException: Option[FetchFailedException] = None
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
@@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
     taskMetrics.registerAccumulator(a)
   }
 
+  private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
+    this._fetchFailedException = Option(fetchFailed)
+  }
+
+  private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 975a6e4..790c1ae 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import java.io.{File, NotSerializableException}
+import java.lang.Thread.UncaughtExceptionHandler
 import java.lang.management.ManagementFactory
 import java.net.{URI, URL}
 import java.nio.ByteBuffer
@@ -52,7 +53,8 @@ private[spark] class Executor(
     executorHostname: String,
     env: SparkEnv,
     userClassPath: Seq[URL] = Nil,
-    isLocal: Boolean = false)
+    isLocal: Boolean = false,
+    uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
   extends Logging {
 
   logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -78,7 +80,7 @@ private[spark] class Executor(
     // Setup an uncaught exception handler for non-local mode.
     // Make any thread terminations due to uncaught exceptions kill the entire
     // executor process to avoid surprising stalls.
-    Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
+    Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
   }
 
   // Start worker thread pool
@@ -342,6 +344,14 @@ private[spark] class Executor(
             }
           }
         }
+        task.context.fetchFailed.foreach { fetchFailure =>
+          // uh-oh.  it appears the user code has caught the fetch-failure without throwing any
+          // other exceptions.  It's *possible* this is what the user meant to do (though highly
+          // unlikely).  So we will log an error and keep going.
+          logError(s"TID ${taskId} completed successfully though internally it encountered " +
+            s"unrecoverable fetch failures!  Most likely this means user code is incorrectly " +
+            s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
+        }
         val taskFinish = System.currentTimeMillis()
         val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
           threadMXBean.getCurrentThreadCpuTime
@@ -402,8 +412,17 @@ private[spark] class Executor(
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
       } catch {
-        case ffe: FetchFailedException =>
-          val reason = ffe.toTaskFailedReason
+        case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+          val reason = task.context.fetchFailed.get.toTaskFailedReason
+          if (!t.isInstanceOf[FetchFailedException]) {
+            // there was a fetch failure in the task, but some user code wrapped that exception
+            // and threw something else.  Regardless, we treat it as a fetch failure.
+            val fetchFailedCls = classOf[FetchFailedException].getName
+            logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
+              s"failed, but the ${fetchFailedCls} was hidden by another " +
+              s"exception.  Spark is handling this like a fetch failure and ignoring the " +
+              s"other exception: $t")
+          }
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
@@ -455,13 +474,17 @@ private[spark] class Executor(
           // Don't forcibly exit unless the exception was inherently fatal, to avoid
           // stopping other tasks unnecessarily.
           if (Utils.isFatalError(t)) {
-            SparkUncaughtExceptionHandler.uncaughtException(t)
+            uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
           }
 
       } finally {
         runningTasks.remove(taskId)
       }
     }
+
+    private def hasFetchFailure: Boolean = {
+      task != null && task.context != null && task.context.fetchFailed.isDefined
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7b726d5..7021372 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 import java.util.Properties
 
-import scala.collection.mutable
-import scala.collection.mutable.HashMap
-
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.util._
 
 /**
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
           memoryManager.synchronized { memoryManager.notifyAll() }
         }
       } finally {
+        // Though we unset the ThreadLocal here, the context member variable itself is still queried
+        // directly in the TaskRunner to check for FetchFailedExceptions.
         TaskContext.unset()
       }
     }
@@ -156,7 +153,7 @@ private[spark] abstract class Task[T](
   var epoch: Long = -1
 
   // Task context, to be initialized in run().
-  @transient protected var context: TaskContextImpl = _
+  @transient var context: TaskContextImpl = _
 
   // The actual Thread on which the task is running, if any. Initialized in run().
   @volatile @transient private var taskThread: Thread = _

http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 498c12e..265a8ac 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle
 
-import org.apache.spark.{FetchFailed, TaskFailedReason}
+import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils
 
@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
  * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
  *
  * Note that bmAddress can be null.
+ *
+ * To prevent user code from hiding this fetch failure, in the constructor we call
+ * [[TaskContext.setFetchFailed()]].  This means that you *must* throw this exception immediately
+ * after creating it -- you cannot create it, check some condition, and then decide to ignore it
+ * (or risk triggering any other exceptions).  See SPARK-19276.
  */
 private[spark] class FetchFailedException(
     bmAddress: BlockManagerId,
@@ -45,6 +50,12 @@ private[spark] class FetchFailedException(
     this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
   }
 
+  // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
+  // which intercepts this exception (possibly wrapping it), the Executor can still tell there was
+  // a fetch failure, and send the correct error msg back to the driver.  We wrap with an Option
+  // because the TaskContext is not defined in some test cases.
+  Option(TaskContext.get()).map(_.setFetchFailed(this))
+
   def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
     Utils.exceptionString(this))
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index b743ff5..8150fff 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
 import java.util.concurrent.{CountDownLatch, TimeUnit}
@@ -27,7 +28,7 @@ import scala.concurrent.duration._
 
 import org.mockito.ArgumentCaptor
 import org.mockito.Matchers.{any, eq => meq}
-import org.mockito.Mockito.{inOrder, when}
+import org.mockito.Mockito.{inOrder, verify, when}
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.concurrent.Eventually
@@ -37,9 +38,12 @@ import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.memory.MemoryManager
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.scheduler.{FakeTask, TaskDescription}
+import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
 import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage.BlockManagerId
 
 class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
 
@@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     }
   }
 
+  test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
+    val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+    sc = new SparkContext(conf)
+    val serializer = SparkEnv.get.closureSerializer.newInstance()
+    val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+
+    // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
+    // the fetch failure.  The executor should still tell the driver that the task failed due to a
+    // fetch failure, not a generic exception from user code.
+    val inputRDD = new FetchFailureThrowingRDD(sc)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+    val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+    val task = new ResultTask(
+      stageId = 1,
+      stageAttemptId = 0,
+      taskBinary = taskBinary,
+      partition = secondRDD.partitions(0),
+      locs = Seq(),
+      outputId = 0,
+      localProperties = new Properties(),
+      serializedTaskMetrics = serializedTaskMetrics
+    )
+
+    val serTask = serializer.serialize(task)
+    val taskDescription = createFakeTaskDescription(serTask)
+
+    val failReason = runTaskAndGetFailReason(taskDescription)
+    assert(failReason.isInstanceOf[FetchFailed])
+  }
+
+  test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
+    // may be a false positive.  And we should call the uncaught exception handler.
+    val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+    sc = new SparkContext(conf)
+    val serializer = SparkEnv.get.closureSerializer.newInstance()
+    val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+
+    // Submit a job where a fetch failure is thrown, but then there is an OOM.  We should treat
+    // the fetch failure as a false positive, and just do normal OOM handling.
+    val inputRDD = new FetchFailureThrowingRDD(sc)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+    val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+    val task = new ResultTask(
+      stageId = 1,
+      stageAttemptId = 0,
+      taskBinary = taskBinary,
+      partition = secondRDD.partitions(0),
+      locs = Seq(),
+      outputId = 0,
+      localProperties = new Properties(),
+      serializedTaskMetrics = serializedTaskMetrics
+    )
+
+    val serTask = serializer.serialize(task)
+    val taskDescription = createFakeTaskDescription(serTask)
+
+    val (failReason, uncaughtExceptionHandler) =
+      runTaskGetFailReasonAndExceptionHandler(taskDescription)
+    // make sure the task failure just looks like a OOM, not a fetch failure
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
+  }
+
   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
     val serializer = new JavaSerializer(conf)
@@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
+    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+  }
+
+  private def runTaskGetFailReasonAndExceptionHandler(
+      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
+    val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
     try {
-      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+      executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
+        uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
-      eventually(timeout(5 seconds), interval(10 milliseconds)) {
+      eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
     } finally {
@@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
     val failureData = statusCaptor.getAllValues.get(1)
-    SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+    val failReason =
+      SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+    (failReason, mockUncaughtExceptionHandler)
+  }
+}
+
+class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
+  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    new Iterator[Int] {
+      override def hasNext: Boolean = true
+      override def next(): Int = {
+        throw new FetchFailedException(
+          bmAddress = BlockManagerId("1", "hostA", 1234),
+          shuffleId = 0,
+          mapId = 0,
+          reduceId = 0,
+          message = "fake fetch failure"
+        )
+      }
+    }
+  }
+  override protected def getPartitions: Array[Partition] = {
+    Array(new SimplePartition)
+  }
+}
+
+class SimplePartition extends Partition {
+  override def index: Int = 0
+}
+
+class FetchFailureHidingRDD(
+    sc: SparkContext,
+    val input: FetchFailureThrowingRDD,
+    throwOOM: Boolean) extends RDD[Int](input) {
+  override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+    val inItr = input.compute(split, context)
+    try {
+      Iterator(inItr.size)
+    } catch {
+      case t: Throwable =>
+        if (throwOOM) {
+          throw new OutOfMemoryError("OOM while handling another exception")
+        } else {
+          throw new RuntimeException("User Exception that hides the original exception", t)
+        }
+    }
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array(new SimplePartition)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8417a7ae/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 511686f..56b8c0b 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,6 +55,9 @@ object MimaExcludes {
     // [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"),
 
+    // [SPARK-19276] Fetch Failure handling robust to user error handling
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"),
+
     // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org