You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by va...@apache.org on 2016/03/29 01:29:23 UTC

spark git commit: [SPARK-14169][CORE] Add UninterruptibleThread

Repository: spark
Updated Branches:
  refs/heads/master b7836492b -> 2f98ee67d


[SPARK-14169][CORE] Add UninterruptibleThread

## What changes were proposed in this pull request?

Extract the workaround for HADOOP-10622 introduced by #11940 into UninterruptibleThread so that we can test and reuse it.

## How was this patch tested?

Unit tests

Author: Shixiong Zhu <sh...@databricks.com>

Closes #11971 from zsxwing/uninterrupt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f98ee67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f98ee67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f98ee67

Branch: refs/heads/master
Commit: 2f98ee67dff0be38a4c92d7d29c8cc8ea8b6576e
Parents: b783649
Author: Shixiong Zhu <sh...@databricks.com>
Authored: Mon Mar 28 16:29:11 2016 -0700
Committer: Marcelo Vanzin <va...@cloudera.com>
Committed: Mon Mar 28 16:29:11 2016 -0700

----------------------------------------------------------------------
 .../spark/util/UninterruptibleThread.scala      | 112 +++++++++++++
 .../spark/util/UninterruptibleThreadSuite.scala | 159 +++++++++++++++++++
 .../execution/streaming/StreamExecution.scala   |  74 +--------
 3 files changed, 279 insertions(+), 66 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f98ee67/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
new file mode 100644
index 0000000..4dcf951
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import javax.annotation.concurrent.GuardedBy
+
+/**
+ * A special Thread that provides "runUninterruptibly" to allow running code without being
+ * interrupted by `Thread.interrupt()`. If `Thread.interrupt()` is called while
+ * "runUninterruptibly" is running, it won't set the interrupted status. Instead, setting the
+ * interrupted status will be deferred until "runUninterruptibly" returns.
+ *
+ * Note: "runUninterruptibly" should be called only in `this` thread.
+ */
+private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
+
+  /** A monitor to protect "uninterruptible" and "shouldInterruptThread" */
+  private val uninterruptibleLock = new Object
+
+  /**
+   * Indicates whether `this` thread is in the uninterruptible status. If so, interrupting
+   * `this` will be deferred until `this` returns to the interruptible status.
+   */
+  @GuardedBy("uninterruptibleLock")
+  private var uninterruptible = false
+
+  /**
+   * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
+   */
+  @GuardedBy("uninterruptibleLock")
+  private var shouldInterruptThread = false
+
+  /**
+   * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
+   * from `f`.
+   *
+   * If this method finds that `interrupt` was called before `f` is run and it's not inside another
+   * `runUninterruptibly`, it will throw `InterruptedException` without running `f`.
+   *
+   * Note: this method should be called only in `this` thread.
+   */
+  def runUninterruptibly[T](f: => T): T = {
+    if (Thread.currentThread() != this) {
+      throw new IllegalStateException(s"Call runUninterruptibly in a wrong thread. " +
+        s"Expected: $this but was ${Thread.currentThread()}")
+    }
+
+    if (uninterruptibleLock.synchronized { uninterruptible }) {
+      // We are already in the uninterruptible status. So just run "f" and return
+      return f
+    }
+
+    uninterruptibleLock.synchronized {
+      // Clear the interrupted status if it's set.
+      if (Thread.interrupted() || shouldInterruptThread) {
+        shouldInterruptThread = false
+        // Since it's interrupted, we don't need to run `f` which may be a long computation.
+        // Throw InterruptedException as we don't have a T to return.
+        throw new InterruptedException()
+      }
+      uninterruptible = true
+    }
+    try {
+      f
+    } finally {
+      uninterruptibleLock.synchronized {
+        uninterruptible = false
+        if (shouldInterruptThread) {
+          // Recover the interrupted status
+          super.interrupt()
+          shouldInterruptThread = false
+        }
+      }
+    }
+  }
+
+  /**
+   * Tests whether `interrupt()` has been called.
+   */
+  override def isInterrupted: Boolean = {
+    super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread }
+  }
+
+  /**
+   * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be
+   * interrupted until it enters into the interruptible status.
+   */
+  override def interrupt(): Unit = {
+    uninterruptibleLock.synchronized {
+      if (uninterruptible) {
+        shouldInterruptThread = true
+      } else {
+        super.interrupt()
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2f98ee67/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
new file mode 100644
index 0000000..39b31f8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import scala.util.Random
+
+import com.google.common.util.concurrent.Uninterruptibles
+
+import org.apache.spark.SparkFunSuite
+
+class UninterruptibleThreadSuite extends SparkFunSuite {
+
+  /** Sleep for `millis` milliseconds and return true if the sleep was interrupted. */
+  private def sleep(millis: Long): Boolean = {
+    try {
+      Thread.sleep(millis)
+      false
+    } catch {
+      case _: InterruptedException =>
+        true
+    }
+  }
+
+  test("interrupt when runUninterruptibly is running") {
+    val enterRunUninterruptibly = new CountDownLatch(1)
+    @volatile var hasInterruptedException = false
+    @volatile var interruptStatusBeforeExit = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        runUninterruptibly {
+          enterRunUninterruptibly.countDown()
+          hasInterruptedException = sleep(1000)
+        }
+        interruptStatusBeforeExit = Thread.interrupted()
+      }
+    }
+    t.start()
+    assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
+    t.interrupt()
+    t.join()
+    assert(hasInterruptedException === false)
+    assert(interruptStatusBeforeExit === true)
+  }
+
+  test("interrupt before runUninterruptibly runs") {
+    val interruptLatch = new CountDownLatch(1)
+    @volatile var hasInterruptedException = false
+    @volatile var interruptStatusBeforeExit = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
+        try {
+          runUninterruptibly {
+            assert(false, "Should not reach here")
+          }
+        } catch {
+          case _: InterruptedException => hasInterruptedException = true
+        }
+        interruptStatusBeforeExit = Thread.interrupted()
+      }
+    }
+    t.start()
+    t.interrupt()
+    interruptLatch.countDown()
+    t.join()
+    assert(hasInterruptedException === true)
+    assert(interruptStatusBeforeExit === false)
+  }
+
+  test("nested runUninterruptibly") {
+    val enterRunUninterruptibly = new CountDownLatch(1)
+    val interruptLatch = new CountDownLatch(1)
+    @volatile var hasInterruptedException = false
+    @volatile var interruptStatusBeforeExit = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        runUninterruptibly {
+          enterRunUninterruptibly.countDown()
+          Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
+          hasInterruptedException = sleep(1)
+          runUninterruptibly {
+            if (sleep(1)) {
+              hasInterruptedException = true
+            }
+          }
+          if (sleep(1)) {
+            hasInterruptedException = true
+          }
+        }
+        interruptStatusBeforeExit = Thread.interrupted()
+      }
+    }
+    t.start()
+    assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
+    t.interrupt()
+    interruptLatch.countDown()
+    t.join()
+    assert(hasInterruptedException === false)
+    assert(interruptStatusBeforeExit === true)
+  }
+
+  test("stress test") {
+    @volatile var hasInterruptedException = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        for (i <- 0 until 100) {
+          try {
+            runUninterruptibly {
+              if (sleep(Random.nextInt(10))) {
+                hasInterruptedException = true
+              }
+              runUninterruptibly {
+                if (sleep(Random.nextInt(10))) {
+                  hasInterruptedException = true
+                }
+              }
+              if (sleep(Random.nextInt(10))) {
+                hasInterruptedException = true
+              }
+            }
+            Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS)
+            // 50% chance to clear the interrupted status
+            if (Random.nextBoolean()) {
+              Thread.interrupted()
+            }
+          } catch {
+            case _: InterruptedException =>
+              // The first runUninterruptibly may throw InterruptedException if the interrupt status
+              // is set before running `f`.
+          }
+        }
+      }
+    }
+    t.start()
+    for (i <- 0 until 400) {
+      Thread.sleep(Random.nextInt(10))
+      t.interrupt()
+    }
+    t.join()
+    assert(hasInterruptedException === false)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2f98ee67/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 60e00d2..c4e410d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
-import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
@@ -34,6 +33,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.ContinuousQueryListener
 import org.apache.spark.sql.util.ContinuousQueryListener._
+import org.apache.spark.util.UninterruptibleThread
 
 /**
  * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@@ -89,9 +89,10 @@ class StreamExecution(
   private[sql] var streamDeathCause: ContinuousQueryException = null
 
   /** The thread that runs the micro-batches of this stream. */
-  private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") {
-    override def run(): Unit = { runBatches() }
-  }
+  private[sql] val microBatchThread =
+    new UninterruptibleThread(s"stream execution thread for $name") {
+      override def run(): Unit = { runBatches() }
+    }
 
   /**
    * A write-ahead-log that records the offsets that are present in each batch. In order to ensure
@@ -102,65 +103,6 @@ class StreamExecution(
   private val offsetLog =
     new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets"))
 
-  /** A monitor to protect "uninterruptible" and "interrupted" */
-  private val uninterruptibleLock = new Object
-
-  /**
-   * Indicates if "microBatchThread" are in the uninterruptible status. If so, interrupting
-   * "microBatchThread" will be deferred until "microBatchThread" enters into the interruptible
-   * status.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var uninterruptible = false
-
-  /**
-   * Indicates if we should interrupt "microBatchThread" when we are leaving the uninterruptible
-   * zone.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var shouldInterruptThread = false
-
-  /**
-   * Interrupt "microBatchThread" if possible. If "microBatchThread" is in the uninterruptible
-   * status, "microBatchThread" won't be interrupted until it enters into the interruptible status.
-   */
-  private def interruptMicroBatchThreadSafely(): Unit = {
-    uninterruptibleLock.synchronized {
-      if (uninterruptible) {
-        shouldInterruptThread = true
-      } else {
-        microBatchThread.interrupt()
-      }
-    }
-  }
-
-  /**
-   * Run `f` uninterruptibly in "microBatchThread". "microBatchThread" won't be interrupted before
-   * returning from `f`.
-   */
-  private def runUninterruptiblyInMicroBatchThread[T](f: => T): T = {
-    assert(Thread.currentThread() == microBatchThread)
-    uninterruptibleLock.synchronized {
-      uninterruptible = true
-      // Clear the interrupted status if it's set.
-      if (Thread.interrupted()) {
-        shouldInterruptThread = true
-      }
-    }
-    try {
-      f
-    } finally {
-      uninterruptibleLock.synchronized {
-        uninterruptible = false
-        if (shouldInterruptThread) {
-          // Recover the interrupted status
-          microBatchThread.interrupt()
-          shouldInterruptThread = false
-        }
-      }
-    }
-  }
-
   /** Whether the query is currently active or not */
   override def isActive: Boolean = state == ACTIVE
 
@@ -294,7 +236,7 @@ class StreamExecution(
     // method. See SPARK-14131.
     //
     // Check to see what new data is available.
-    val newData = runUninterruptiblyInMicroBatchThread {
+    val newData = microBatchThread.runUninterruptibly {
       uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
     }
     availableOffsets ++= newData
@@ -305,7 +247,7 @@ class StreamExecution(
       // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set
       // the file permission, we should not interrupt "microBatchThread" when running this method.
       // See SPARK-14131.
-      runUninterruptiblyInMicroBatchThread {
+      microBatchThread.runUninterruptibly {
         assert(
           offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)),
           s"Concurrent update to the log.  Multiple streaming jobs detected for $currentBatchId")
@@ -395,7 +337,7 @@ class StreamExecution(
     // intentionally
     state = TERMINATED
     if (microBatchThread.isAlive) {
-      interruptMicroBatchThreadSafely()
+      microBatchThread.interrupt()
       microBatchThread.join()
     }
     logInfo(s"Query $name was stopped")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org