Posted to commits@spark.apache.org by yh...@apache.org on 2016/12/21 00:05:09 UTC

spark git commit: [SPARK-18928][BRANCH-2.0] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter

Repository: spark
Updated Branches:
  refs/heads/branch-2.0 678d91c1d -> 2aae220b5


[SPARK-18928][BRANCH-2.0] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter

This is a branch-2.0 backport of #16340; the original description follows:

## What changes were proposed in this pull request?

In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing on a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and `UnsafeSorter`-based sorts. This can cause interrupted or cancelled tasks to continue running and become zombies (as also described in #16189).
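As background, the cooperative-cancellation pattern this relies on looks roughly like the sketch below. This is illustrative only; `readAll` and its parameters are made up for the example and are not part of the patch.

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Illustrative sketch of cooperative cancellation: a long-running read loop
// polls the task's interrupt flag and bails out with TaskKilledException,
// which lets the executor treat the task as killed rather than letting it
// keep running as a zombie.
def readAll[T](context: TaskContext, rows: Iterator[T])(consume: T => Unit): Unit = {
  while (rows.hasNext) {
    if (context != null && context.isInterrupted()) {
      throw new TaskKilledException
    }
    consume(rows.next())
  }
}
```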

This patch aims to fix this problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have simply wrapped the relevant iterators in `InterruptibleIterator`, but in some cases that would impose a performance penalty or would not be effective due to certain special uses of iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into the existing iterator subclasses.
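For comparison, the two styles look roughly like this. This is an illustrative sketch; `CancellableRowIterator` is a made-up name, not a class introduced by the patch.

```scala
import org.apache.spark.{InterruptibleIterator, TaskContext, TaskKilledException}

// Option A (not used here): wrap the existing iterator. This adds an extra
// object and indirection per element, and it cannot help callers that drive
// the sorter via getNumRecords() rather than hasNext().
def wrapped[T](context: TaskContext, rows: Iterator[T]): Iterator[T] =
  new InterruptibleIterator[T](context, rows)

// Option B (what the patch does): inline the same check directly into the
// existing iterator subclass, shown here on a generic row iterator.
class CancellableRowIterator[T](context: TaskContext, rows: Iterator[T]) extends Iterator[T] {
  override def hasNext: Boolean = {
    if (context != null && context.isInterrupted()) throw new TaskKilledException
    rows.hasNext
  }
  override def next(): T = rows.next()
}
```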

## How was this patch tested?

Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks, one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes added here.
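For reference, the file-scan reproduction is roughly the following in `spark-shell`. The Parquet path and the sleep are placeholders; any input large enough to keep a scan task busy past the cancellation works.

```scala
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

// Kick off a long-running scan in the background. "/tmp/huge-table" is a
// placeholder for any dataset big enough that the scan outlives the cancel.
val df = spark.read.parquet("/tmp/huge-table")
Future { df.selectExpr("count(*)").collect() }

// Let the scan start, then cancel all running jobs. Without this patch the
// scan tasks can keep reading as zombies; with it they should fail fast
// with TaskKilledException.
Thread.sleep(5000)
spark.sparkContext.cancelAllJobs()
```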

Author: Josh Rosen <jo...@databricks.com>

Closes #16357 from JoshRosen/sql-task-interruption-branch-2.0.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2aae220b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2aae220b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2aae220b

Branch: refs/heads/branch-2.0
Commit: 2aae220b536065f55b2cf644a2a223aab0d051d0
Parents: 678d91c
Author: Josh Rosen <jo...@databricks.com>
Authored: Tue Dec 20 16:05:04 2016 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Dec 20 16:05:04 2016 -0800

----------------------------------------------------------------------
 .../collection/unsafe/sort/UnsafeInMemorySorter.java    | 11 +++++++++++
 .../collection/unsafe/sort/UnsafeSorterSpillReader.java | 11 +++++++++++
 .../spark/sql/execution/datasources/FileScanRDD.scala   | 12 ++++++++++--
 .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala  |  9 ++++++++-
 4 files changed, 40 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index b517371..2bd756f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -21,6 +21,8 @@ import java.util.Comparator;
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -226,6 +228,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -256,6 +259,14 @@ public final class UnsafeInMemorySorter {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);

http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 1d588c3..a3f04de 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -22,6 +22,8 @@ import java.io.*;
 import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.unsafe.Platform;
@@ -44,6 +46,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -73,6 +76,14 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+    // `hasNext()` because it's technically possible for the caller to be relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {

http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 1314c94..b4deaf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -88,7 +88,15 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
-      def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      def hasNext: Boolean = {
+        // Kill the task in case it has been marked as killed. This logic is from
+        // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+        // to avoid performance overhead.
+        if (context.isInterrupted()) {
+          throw new TaskKilledException
+        }
+        (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      }
       def next() = {
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch based scan, so that we

http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 95058cc..a0afe34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{Partition, SparkContext, TaskContext, TaskKilledException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -541,6 +541,13 @@ private[jdbc] class JDBCRDD(
     }
 
     override def hasNext: Boolean = {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+      // to avoid performance overhead and to minimize modified code since it's not easy to
+      // wrap this Iterator without re-indenting tons of code.
+      if (context.isInterrupted()) {
+        throw new TaskKilledException
+      }
       if (!finished) {
         if (!gotNext) {
           nextValue = getNext()

