Posted to commits@spark.apache.org by rx...@apache.org on 2015/04/02 01:06:14 UTC

spark git commit: [SPARK-6578] [core] Fix thread-safety issue in outbound path of network library.

Repository: spark
Updated Branches:
  refs/heads/master fb25e8c7f -> f084c5de1


[SPARK-6578] [core] Fix thread-safety issue in outbound path of network library.

While the inbound path of a Netty pipeline is thread-safe, the outbound
path is not. As a result, multiple threads can race to write messages
to the next stage of the pipeline.

The network library sometimes breaks a single RPC message into multiple
buffers internally to avoid copying data (see MessageEncoder). This can
result in the following scenario (where "FxBy" means "frame x, buffer y"):

               T1         F1B1            F1B2
                            \               \
                             \               \
               socket        F1B1   F2B1    F1B2  F2B2
                                     /             /
                                    /             /
               T2                  F2B1         F2B2

The frames can then no longer be rebuilt on the receiving side because
buffers from different messages have been interleaved on the wire.
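
The scenario above can be reproduced without any threads by merging the
per-thread buffer sequences in one of the orders a racy outbound path might
produce. This is only an illustrative sketch: the class name, roundRobin()
and framesIntact() are hypothetical helpers, not part of the network
library, and the round-robin order is just one possible interleaving.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class InterleavingSketch {

  /**
   * Merges the per-thread buffer sequences round-robin. This models one order
   * in which unsynchronized writers could reach the next pipeline stage.
   */
  public static List<String> roundRobin(List<List<String>> perThread) {
    int longest = 0;
    for (List<String> buffers : perThread) {
      longest = Math.max(longest, buffers.size());
    }
    List<String> wire = new ArrayList<>();
    for (int i = 0; i < longest; i++) {
      for (List<String> buffers : perThread) {
        if (i < buffers.size()) {
          wire.add(buffers.get(i));
        }
      }
    }
    return wire;
  }

  /**
   * Buffers are named "F<frame>B<index>". A two-buffer frame is intact only
   * if its B2 buffer directly follows its B1 buffer on the wire.
   */
  public static boolean framesIntact(List<String> wire) {
    for (int i = 0; i < wire.size(); i++) {
      String buf = wire.get(i);
      if (buf.endsWith("B1")) {
        String expectedNext = buf.substring(0, buf.length() - 1) + "2";
        if (i + 1 >= wire.size() || !wire.get(i + 1).equals(expectedNext)) {
          return false;
        }
      }
    }
    return true;
  }

  public static void main(String[] args) {
    List<String> t1 = Arrays.asList("F1B1", "F1B2");
    List<String> t2 = Arrays.asList("F2B1", "F2B2");
    List<String> wire = roundRobin(Arrays.asList(t1, t2));
    // Prints: [F1B1, F2B1, F1B2, F2B2] intact=false
    System.out.println(wire + " intact=" + framesIntact(wire));
  }
}
```

If each thread instead contributes a single entry per frame, no merge order
can separate a frame's header from its body.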

The fix wraps each multi-buffer message in a single `FileRegion` object
so that it is handed "atomically" to the next pipeline handler.
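
The idea of a single object that drains a header and then a body can be
sketched with plain java.nio types. This is a simplified stand-in and not
the MessageWithHeader class from this commit: it uses ByteBuffer for both
parts instead of Netty's ByteBuf and FileRegion, and HeaderAndBodySketch,
ShortWriteChannel and drain() are hypothetical names chosen for the example.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;

public class HeaderAndBodySketch {

  private final ByteBuffer header;
  private final ByteBuffer body;
  private final long count;
  private long transferred;

  public HeaderAndBodySketch(ByteBuffer header, ByteBuffer body) {
    this.header = header;
    this.body = body;
    this.count = header.remaining() + body.remaining();
  }

  public long count() { return count; }

  public long transferred() { return transferred; }

  /** Writes what the target accepts; the body is touched only once the header is drained. */
  public long transferTo(WritableByteChannel target, long position) throws IOException {
    if (position != transferred) {
      throw new IllegalArgumentException("Invalid position.");
    }
    long written = 0;
    if (header.hasRemaining()) {
      written += target.write(header);
      if (header.hasRemaining()) {
        // Short write: report progress and let the caller retry later.
        transferred += written;
        return written;
      }
    }
    written += target.write(body);
    transferred += written;
    return written;
  }

  /** Test channel that accepts at most `limit` bytes per call, to force short writes. */
  public static final class ShortWriteChannel implements WritableByteChannel {
    private final ByteBuffer sink;
    private final int limit;

    public ShortWriteChannel(int capacity, int limit) {
      this.sink = ByteBuffer.allocate(capacity);
      this.limit = limit;
    }

    @Override
    public int write(ByteBuffer src) {
      int n = Math.min(limit, src.remaining());
      for (int i = 0; i < n; i++) {
        sink.put(src.get());
      }
      return n;
    }

    @Override
    public boolean isOpen() { return true; }

    @Override
    public void close() { }

    public byte[] data() { return sink.array(); }
  }

  /** Calls transferTo() until the whole message has been written; returns the bytes. */
  public static byte[] drain(HeaderAndBodySketch msg, int limit) {
    ShortWriteChannel channel = new ShortWriteChannel((int) msg.count(), limit);
    try {
      while (msg.transferred() < msg.count()) {
        msg.transferTo(channel, msg.transferred());
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    return channel.data();
  }

  public static void main(String[] args) {
    ByteBuffer header = ByteBuffer.allocate(8).putLong(0, 42L);
    ByteBuffer body = ByteBuffer.allocate(8).putLong(0, 84L);
    ByteBuffer result = ByteBuffer.wrap(drain(new HeaderAndBodySketch(header, body), 3));
    System.out.println(result.getLong() + " " + result.getLong());
  }
}
```

Because the header and body travel as one object, the next handler sees
either the whole message or none of it, even when the underlying channel
accepts only a few bytes per call.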

Author: Marcelo Vanzin <va...@cloudera.com>

Closes #5234 from vanzin/SPARK-6578 and squashes the following commits:

16b2d70 [Marcelo Vanzin] Forgot to update a type.
c9c2e4e [Marcelo Vanzin] Review comments: simplify some code.
9c888ac [Marcelo Vanzin] Small style nits.
8474bab [Marcelo Vanzin] Fix multiple calls to MessageWithHeader.transferTo().
e26509f [Marcelo Vanzin] Merge branch 'master' into SPARK-6578
c503f6c [Marcelo Vanzin] Implement a custom FileRegion instead of using locks.
84aa7ce [Marcelo Vanzin] Rename handler to the correct name.
432f3bd [Marcelo Vanzin] Remove unneeded method.
8d70e60 [Marcelo Vanzin] Fix thread-safety issue in outbound path of network library.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f084c5de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f084c5de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f084c5de

Branch: refs/heads/master
Commit: f084c5de14eb10a6aba82a39e03e7877926ebb9e
Parents: fb25e8c
Author: Marcelo Vanzin <va...@cloudera.com>
Authored: Wed Apr 1 16:06:11 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Apr 1 16:06:11 2015 -0700

----------------------------------------------------------------------
 network/common/pom.xml                          |   5 +
 .../spark/network/protocol/MessageEncoder.java  |   6 +-
 .../network/protocol/MessageWithHeader.java     | 106 +++++++++++++++
 .../spark/network/ByteArrayWritableChannel.java |  55 ++++++++
 .../org/apache/spark/network/ProtocolSuite.java |  46 +++++--
 .../protocol/MessageWithHeaderSuite.java        | 129 +++++++++++++++++++
 .../common/src/test/resources/log4j.properties  |  27 ++++
 7 files changed, 364 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/pom.xml
----------------------------------------------------------------------
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 7b51845..22c738b 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -80,6 +80,11 @@
       <artifactId>mockito-all</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-log4j12</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
   <build>

http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 91d1e8a..0f999f5 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -72,9 +72,11 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
     in.encode(header);
     assert header.writableBytes() == 0;
 
-    out.add(header);
     if (body != null && bodyLength > 0) {
-      out.add(body);
+      out.add(new MessageWithHeader(header, body, bodyLength));
+    } else {
+      out.add(header);
     }
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
new file mode 100644
index 0000000..215a851
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import io.netty.util.ReferenceCountUtil;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body) to avoid
+ * copying the body's content.
+ */
+class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
+
+  private final ByteBuf header;
+  private final int headerLength;
+  private final Object body;
+  private final long bodyLength;
+  private long totalBytesTransferred;
+
+  MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
+    Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
+      "Body must be a ByteBuf or a FileRegion.");
+    this.header = header;
+    this.headerLength = header.readableBytes();
+    this.body = body;
+    this.bodyLength = bodyLength;
+  }
+
+  @Override
+  public long count() {
+    return headerLength + bodyLength;
+  }
+
+  @Override
+  public long position() {
+    return 0;
+  }
+
+  @Override
+  public long transfered() {
+    return totalBytesTransferred;
+  }
+
+  @Override
+  public long transferTo(WritableByteChannel target, long position) throws IOException {
+    Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
+    long written = 0;
+
+    if (position < headerLength) {
+      written += copyByteBuf(header, target);
+      if (header.readableBytes() > 0) {
+        totalBytesTransferred += written;
+        return written;
+      }
+    }
+
+    if (body instanceof FileRegion) {
+      // Adjust the position. If the write is happening as part of the same call where the header
+      // (or some part of it) is written, `position` will be less than the header size, so we want
+      // to start from position 0 in the FileRegion object. Otherwise, we start from the position
+      // requested by the caller.
+      long bodyPos = position > headerLength ? position - headerLength : 0;
+      written += ((FileRegion)body).transferTo(target, bodyPos);
+    } else if (body instanceof ByteBuf) {
+      written += copyByteBuf((ByteBuf) body, target);
+    }
+
+    totalBytesTransferred += written;
+    return written;
+  }
+
+  @Override
+  protected void deallocate() {
+    header.release();
+    ReferenceCountUtil.release(body);
+  }
+
+  private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
+    int written = target.write(buf.nioBuffer());
+    buf.skipBytes(written);
+    return written;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
new file mode 100644
index 0000000..b525ed6
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+public class ByteArrayWritableChannel implements WritableByteChannel {
+
+  private final byte[] data;
+  private int offset;
+
+  public ByteArrayWritableChannel(int size) {
+    this.data = new byte[size];
+    this.offset = 0;
+  }
+
+  public byte[] getData() {
+    return data;
+  }
+
+  @Override
+  public int write(ByteBuffer src) {
+    int available = src.remaining();
+    src.get(data, offset, available);
+    offset += available;
+    return available;
+  }
+
+  @Override
+  public void close() {
+
+  }
+
+  @Override
+  public boolean isOpen() {
+    return true;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 43dc0cf..860dd6d 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -17,26 +17,34 @@
 
 package org.apache.spark.network;
 
+import java.util.List;
+
+import com.google.common.primitives.Ints;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.FileRegion;
 import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.MessageToMessageEncoder;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 
-import org.apache.spark.network.protocol.Message;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.RpcRequest;
-import org.apache.spark.network.protocol.RpcFailure;
-import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.Message;
 import org.apache.spark.network.protocol.MessageDecoder;
 import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
 import org.apache.spark.network.util.NettyUtils;
 
 public class ProtocolSuite {
   private void testServerToClient(Message msg) {
-    EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder());
+    EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
+      new MessageEncoder());
     serverChannel.writeOutbound(msg);
 
     EmbeddedChannel clientChannel = new EmbeddedChannel(
@@ -51,7 +59,8 @@ public class ProtocolSuite {
   }
 
   private void testClientToServer(Message msg) {
-    EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder());
+    EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
+      new MessageEncoder());
     clientChannel.writeOutbound(msg);
 
     EmbeddedChannel serverChannel = new EmbeddedChannel(
@@ -83,4 +92,25 @@ public class ProtocolSuite {
     testServerToClient(new RpcFailure(0, "this is an error"));
     testServerToClient(new RpcFailure(0, ""));
   }
+
+  /**
+   * Handler to transform a FileRegion into a byte buffer. EmbeddedChannel doesn't actually transfer
+   * bytes, but messages, so this is needed so that the frame decoder on the receiving side can
+   * understand what MessageWithHeader actually contains.
+   */
+  private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {
+
+    @Override
+    public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
+      throws Exception {
+
+      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(Ints.checkedCast(in.count()));
+      while (in.transfered() < in.count()) {
+        in.transferTo(channel, in.transfered());
+      }
+      out.add(Unpooled.wrappedBuffer(channel.getData()));
+    }
+
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
new file mode 100644
index 0000000..ff98509
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.FileRegion;
+import io.netty.util.AbstractReferenceCounted;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.ByteArrayWritableChannel;
+
+public class MessageWithHeaderSuite {
+
+  @Test
+  public void testSingleWrite() throws Exception {
+    testFileRegionBody(8, 8);
+  }
+
+  @Test
+  public void testShortWrite() throws Exception {
+    testFileRegionBody(8, 1);
+  }
+
+  @Test
+  public void testByteBufBody() throws Exception {
+    ByteBuf header = Unpooled.copyLong(42);
+    ByteBuf body = Unpooled.copyLong(84);
+    MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());
+
+    ByteBuf result = doWrite(msg, 1);
+    assertEquals(msg.count(), result.readableBytes());
+    assertEquals(42, result.readLong());
+    assertEquals(84, result.readLong());
+  }
+
+  private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
+    ByteBuf header = Unpooled.copyLong(42);
+    int headerLength = header.readableBytes();
+    TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
+    MessageWithHeader msg = new MessageWithHeader(header, region, region.count());
+
+    ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
+    assertEquals(headerLength + region.count(), result.readableBytes());
+    assertEquals(42, result.readLong());
+    for (long i = 0; i < 8; i++) {
+      assertEquals(i, result.readLong());
+    }
+  }
+
+  private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
+    int writes = 0;
+    ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
+    while (msg.transfered() < msg.count()) {
+      msg.transferTo(channel, msg.transfered());
+      writes++;
+    }
+    assertTrue("Not enough writes!", minExpectedWrites <= writes);
+    return Unpooled.wrappedBuffer(channel.getData());
+  }
+
+  private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
+
+    private final int writeCount;
+    private final int writesPerCall;
+    private int written;
+
+    TestFileRegion(int totalWrites, int writesPerCall) {
+      this.writeCount = totalWrites;
+      this.writesPerCall = writesPerCall;
+    }
+
+    @Override
+    public long count() {
+      return 8 * writeCount;
+    }
+
+    @Override
+    public long position() {
+      return 0;
+    }
+
+    @Override
+    public long transfered() {
+      return 8 * written;
+    }
+
+    @Override
+    public long transferTo(WritableByteChannel target, long position) throws IOException {
+      for (int i = 0; i < writesPerCall; i++) {
+        ByteBuf buf = Unpooled.copyLong((position / 8) + i);
+        ByteBuffer nio = buf.nioBuffer();
+        while (nio.remaining() > 0) {
+          target.write(nio);
+        }
+        buf.release();
+        written++;
+      }
+      return 8 * writesPerCall;
+    }
+
+    @Override
+    protected void deallocate() {
+    }
+
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f084c5de/network/common/src/test/resources/log4j.properties
----------------------------------------------------------------------
diff --git a/network/common/src/test/resources/log4j.properties b/network/common/src/test/resources/log4j.properties
new file mode 100644
index 0000000..e8da774
--- /dev/null
+++ b/network/common/src/test/resources/log4j.properties
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=DEBUG, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Silence verbose logs from 3rd-party libraries.
+log4j.logger.io.netty=INFO

