Posted to commits@spark.apache.org by we...@apache.org on 2018/01/10 07:01:19 UTC

spark git commit: [SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel

Repository: spark
Updated Branches:
  refs/heads/master e59983724 -> edf0a48c2


[SPARK-22982] Remove unsafe asynchronous close() call from FileDownloadChannel

## What changes were proposed in this pull request?

This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor and then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call, then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.**

For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel.
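
The control flow above can be sketched outside of Spark with a plain `java.nio.channels.Pipe`. This is only an illustration under stated assumptions: `openChannelSketch` and `readAll` are hypothetical names, and a background thread stands in for the asynchronous stream callback that Spark's `TransportClient` would invoke.

```scala
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import java.nio.channels.{Pipe, ReadableByteChannel}

object PipeDownloadSketch {
  // Hypothetical stand-in for openChannel(): returns the pipe's source end
  // immediately while a background thread (playing the role of the stream
  // callback) writes the received data into the sink end.
  def openChannelSketch(payload: Array[Byte]): ReadableByteChannel = {
    val pipe = Pipe.open()
    val producer = new Thread(new Runnable {
      override def run(): Unit = {
        val sink = pipe.sink()
        try {
          val buf = ByteBuffer.wrap(payload)
          while (buf.hasRemaining) sink.write(buf)
        } finally {
          // Closing the sink makes reads on the source end return EOF (-1).
          sink.close()
        }
      }
    })
    producer.setDaemon(true)
    producer.start()
    pipe.source()
  }

  // The caller owns the returned channel: it reads until EOF and closes the
  // channel itself in a finally block.
  def readAll(channel: ReadableByteChannel): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val buf = ByteBuffer.allocate(64)
    try {
      while (channel.read(buf) != -1) {
        buf.flip()
        out.write(buf.array(), 0, buf.limit())
        buf.clear()
      }
    } finally {
      channel.close()
    }
    out.toByteArray
  }
}
```

Note that only the reading side ever closes the source channel in this sketch; the producing side only closes the sink.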

Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned. In order to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds a `setError(e: Throwable)` method, then calls this `setError()` method in the `FileDownloadCallback`'s `onFailure` method.

It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads.

The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best case, this fails as a read of a closed file descriptor; in the worst case, the file descriptor number has been reused by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread.

The solution here is to remove the `source.close()` call in `setError()`: the thread that is performing the `read()` calls is responsible for closing the channel in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`.

After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded.
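
The resulting pattern can be tried in isolation. The following is a minimal sketch, not Spark's class: `ErrorAwareChannel` and `demo` are hypothetical names, and a plain thread plays the role of the RPC callback that records the error and then closes the sink.

```scala
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.{Pipe, ReadableByteChannel}
import scala.util.{Failure, Success, Try}

object ErrorPropagatingChannelSketch {
  // Hypothetical wrapper modelled on the description above: setError() only
  // records the error and never closes the source channel.
  class ErrorAwareChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {
    @volatile private var error: Throwable = _

    def setError(e: Throwable): Unit = { error = e }

    override def read(dst: ByteBuffer): Int = {
      Try(source.read(dst)) match {
        // Check the recorded error first, even if the read itself succeeded
        // (for example by returning EOF once the sink was closed).
        case _ if error != null => throw error
        case Success(bytesRead) => bytesRead
        case Failure(readErr) => throw readErr
      }
    }

    override def close(): Unit = source.close()
    override def isOpen: Boolean = source.isOpen
  }

  def demo(): String = {
    val pipe = Pipe.open()
    val channel = new ErrorAwareChannel(pipe.source())
    // Stand-in for the failure callback: record the error, then close the
    // sink so that a reader blocked in read() is woken up with EOF.
    val callback = new Thread(new Runnable {
      override def run(): Unit = {
        channel.setError(new IOException("stream does not exist"))
        pipe.sink().close()
      }
    })
    callback.setDaemon(true)
    callback.start()
    try {
      channel.read(ByteBuffer.allocate(16))
      "no error"
    } catch {
      case e: IOException => e.getMessage
    } finally {
      // Only the reading thread closes the source channel.
      channel.close()
    }
  }
}
```

Because the error is stored before the sink is closed, the reader that observes EOF also observes the recorded error and rethrows it instead of returning `-1`.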

This patch also makes a slight cleanup to a dodgy-looking `catch { case e: Exception => ... }` block, replacing it with the safer `try-finally` error handling idiom provided by `Utils.tryWithSafeFinallyAndFailureCallbacks`.
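
For readers unfamiliar with this kind of helper, here is a simplified sketch of the idea. It is an assumption-laden illustration, not Spark's implementation: `tryWithCleanupOnFailure` is a hypothetical name, and the real utility does more than this.

```scala
object CleanupOnFailureSketch {
  // Hypothetical helper: run `block`; if it throws, run `cleanup` and rethrow
  // the original exception. A failure inside `cleanup` is attached to the
  // original as a suppressed exception instead of replacing it.
  def tryWithCleanupOnFailure[T](block: => T)(cleanup: => Unit): T = {
    try {
      block
    } catch {
      case original: Throwable =>
        try {
          cleanup
        } catch {
          case t: Throwable if t ne original => original.addSuppressed(t)
        }
        throw original
    }
  }
}
```

The point of such a helper is that the exception raised by the main block is always the one the caller sees, while resources such as the two pipe ends are still released on the failure path.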

This bug was introduced in SPARK-11956 / #9941 and is present in Spark 1.6.0+.

## How was this patch tested?

This fix was tested manually against a workload which non-deterministically hit this bug.

Author: Josh Rosen <jo...@databricks.com>

Closes #20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/edf0a48c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/edf0a48c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/edf0a48c

Branch: refs/heads/master
Commit: edf0a48c2ec696b92ed6a96dcee6eeb1a046b20b
Parents: e599837
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed Jan 10 15:01:11 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Jan 10 15:01:11 2018 +0800

----------------------------------------------------------------------
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    | 37 ++++++++++++--------
 .../shuffle/IndexShuffleBlockResolver.scala     | 21 ++++++++---
 2 files changed, 39 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/edf0a48c/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index f951591..a2936d6 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv(
 
     val pipe = Pipe.open()
     val source = new FileDownloadChannel(pipe.source())
-    try {
+    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
       val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
       val callback = new FileDownloadCallback(pipe.sink(), source, client)
       client.stream(parsedUri.getPath(), callback)
-    } catch {
-      case e: Exception =>
-        pipe.sink().close()
-        source.close()
-        throw e
-    }
+    })(catchBlock = {
+      pipe.sink().close()
+      source.close()
+    })
 
     source
   }
@@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv(
     fileDownloadFactory.createClient(host, port)
   }
 
-  private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
+  private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {
 
     @volatile private var error: Throwable = _
 
     def setError(e: Throwable): Unit = {
+      // This setError callback is invoked by internal RPC threads in order to propagate remote
+      // exceptions to application-level threads which are reading from this channel. When an
+      // RPC error occurs, the RPC system will call setError() and then will close the
+      // Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe
+      // sink will cause `source.read()` operations to return EOF, unblocking the application-level
+      // reading thread. Thus there is no need to actually call `source.close()` here in the
+      // onError() callback and, in fact, calling it here would be dangerous because the close()
+      // would be asynchronous with respect to the read() call and could trigger race-conditions
+      // that lead to data corruption. See the PR for SPARK-22982 for more details on this topic.
       error = e
-      source.close()
     }
 
     override def read(dst: ByteBuffer): Int = {
       Try(source.read(dst)) match {
+        // See the documentation above in setError(): if an RPC error has occurred then setError()
+        // will be called to propagate the RPC error and then `source`'s corresponding
+        // Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate
+        // the remote RPC exception (and not any exceptions triggered by the pipe close, such as
+        // ChannelClosedException), hence this `error != null` check:
+        case _ if error != null => throw error
         case Success(bytesRead) => bytesRead
-        case Failure(readErr) =>
-          if (error != null) {
-            throw error
-          } else {
-            throw readErr
-          }
+        case Failure(readErr) => throw readErr
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/edf0a48c/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 1554048..266ee42 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -18,8 +18,8 @@
 package org.apache.spark.shuffle
 
 import java.io._
-
-import com.google.common.io.ByteStreams
+import java.nio.channels.Channels
+import java.nio.file.Files
 
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(
     // find out the consolidated file, then the offset within that from our index
     val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
 
-    val in = new DataInputStream(new FileInputStream(indexFile))
+    // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
+    // which is incorrectly using our file descriptor then this code will fetch the wrong offsets
+    // (which may cause a reducer to be sent a different reducer's data). The explicit position
+    // checks added here were a useful debugging aid during SPARK-22982 and may help prevent this
+    // class of issue from re-occurring in the future which is why they are left here even though
+    // SPARK-22982 is fixed.
+    val channel = Files.newByteChannel(indexFile.toPath)
+    channel.position(blockId.reduceId * 8)
+    val in = new DataInputStream(Channels.newInputStream(channel))
     try {
-      ByteStreams.skipFully(in, blockId.reduceId * 8)
       val offset = in.readLong()
       val nextOffset = in.readLong()
+      val actualPosition = channel.position()
+      val expectedPosition = blockId.reduceId * 8 + 16
+      if (actualPosition != expectedPosition) {
+        throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
+          s"expected $expectedPosition but actual position was $actualPosition.")
+      }
       new FileSegmentManagedBuffer(
         transportConf,
         getDataFile(blockId.shuffleId, blockId.mapId),

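The position check added to `IndexShuffleBlockResolver` above can be exercised on its own. The sketch below assumes an index file that is simply a sequence of 8-byte big-endian offsets, as the `reduceId * 8` arithmetic in the diff suggests; `writeIndex` and `readSegment` are hypothetical helper names, not Spark APIs.

```scala
import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream}
import java.nio.channels.Channels
import java.nio.file.Files

object IndexPositionCheckSketch {
  // Write the offsets as consecutive 8-byte longs (assumed index layout).
  def writeIndex(file: File, offsets: Seq[Long]): Unit = {
    val out = new DataOutputStream(new FileOutputStream(file))
    try {
      offsets.foreach(o => out.writeLong(o))
    } finally {
      out.close()
    }
  }

  // Return (offset, length) for one reducer, verifying that exactly two longs
  // (16 bytes) were consumed from the expected starting position.
  def readSegment(file: File, reduceId: Int): (Long, Long) = {
    val channel = Files.newByteChannel(file.toPath)
    try {
      channel.position(reduceId * 8L)
      val in = new DataInputStream(Channels.newInputStream(channel))
      val offset = in.readLong()
      val nextOffset = in.readLong()
      val actualPosition = channel.position()
      val expectedPosition = reduceId * 8L + 16
      if (actualPosition != expectedPosition) {
        throw new IllegalStateException(
          s"Incorrect channel position: expected $expectedPosition but was $actualPosition")
      }
      (offset, nextOffset - offset)
    } finally {
      channel.close()
    }
  }
}
```

If some other code moved the position of the underlying file descriptor between the two reads, the final position would no longer match and the check would fail loudly instead of silently returning the wrong offsets.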
