Posted to commits@spark.apache.org by an...@apache.org on 2015/12/01 02:22:09 UTC

[2/2] spark git commit: [SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer.

[SPARK-12007][NETWORK] Avoid copies in the network lib's RPC layer.

This change seems large, but most of it is just replacing `byte[]`
with `ByteBuffer` and `new byte[]` with `ByteBuffer.allocate()`,
since it changes the network library's API.

The following parts of the code have meaningful changes:

- The Message implementations were changed to inherit from a new
  AbstractMessage that can optionally hold a reference to a body
  (in the form of a ManagedBuffer); this is similar to how
  ResponseWithBody worked before, except now it's not restricted
  to just responses.

- The TransportFrameDecoder was largely rewritten to avoid copies as
  much as possible; it no longer relies on CompositeByteBuf to
  accumulate incoming data, since CompositeByteBuf has issues when
  slices are retained. The code can now create frames without copying
  bytes, except for a few bytes (containing the frame length) in very
  rare cases; a rough sketch follows this list.

- Some minor changes in the SASL layer to convert things back to
  `byte[]` since the JDK SASL API operates on those.
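
As a rough illustration of that decoder approach, here is a minimal
sketch in its spirit. All names are hypothetical, the multi-buffer
frame case is elided, and this is not the code in
TransportFrameDecoder.java:

    import java.util.LinkedList;

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    // Illustrative sketch only: accumulate incoming ByteBufs in a list and
    // emit whole frames as retained slices, without copying the payload.
    class FrameAccumulator {
      private final LinkedList<ByteBuf> buffers = new LinkedList<ByteBuf>();
      private long totalSize = 0;

      void feed(ByteBuf data) {
        buffers.add(data);
        totalSize += data.readableBytes();
      }

      /** Returns the next frame, or null if a whole frame hasn't arrived. */
      ByteBuf nextFrame() {
        if (totalSize < 8) {
          return null;                           // length field incomplete
        }
        long frameSize = peekFrameLength() - 8;  // the length counts itself
        if (totalSize - 8 < frameSize) {
          return null;                           // frame body incomplete
        }
        skip(8);                                 // consume the length field
        if (frameSize == 0) {
          return Unpooled.EMPTY_BUFFER;
        }
        ByteBuf head = buffers.getFirst();
        if (head.readableBytes() < frameSize) {
          // The real decoder also stitches frames out of several buffers;
          // elided to keep the sketch short.
          throw new UnsupportedOperationException("multi-buffer frame");
        }
        // Common case: a retained slice of the head buffer, zero copies.
        ByteBuf frame = head.slice(head.readerIndex(), (int) frameSize).retain();
        head.skipBytes((int) frameSize);
        totalSize -= frameSize;
        if (!head.isReadable()) {
          buffers.removeFirst().release();
        }
        return frame;
      }

      // The only bytes ever copied: the 8 length bytes, and only when they
      // straddle two incoming buffers.
      private long peekFrameLength() {
        ByteBuf head = buffers.getFirst();
        if (head.readableBytes() >= 8) {
          return head.getLong(head.readerIndex());
        }
        ByteBuf tmp = Unpooled.buffer(8);
        for (ByteBuf b : buffers) {
          int n = Math.min(8 - tmp.writerIndex(), b.readableBytes());
          tmp.writeBytes(b, b.readerIndex(), n);
          if (tmp.writerIndex() == 8) {
            break;
          }
        }
        return tmp.getLong(0);
      }

      private void skip(int n) {
        totalSize -= n;
        while (n > 0) {
          ByteBuf head = buffers.getFirst();
          int toSkip = Math.min(n, head.readableBytes());
          head.skipBytes(toSkip);
          if (!head.isReadable()) {
            buffers.removeFirst().release();
          }
          n -= toSkip;
        }
      }
    }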

Author: Marcelo Vanzin <va...@cloudera.com>

Closes #9987 from vanzin/SPARK-12007.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9bf21206
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9bf21206
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9bf21206

Branch: refs/heads/master
Commit: 9bf2120672ae0f620a217ccd96bef189ab75e0d6
Parents: 0a46e43
Author: Marcelo Vanzin <va...@cloudera.com>
Authored: Mon Nov 30 17:22:05 2015 -0800
Committer: Andrew Or <an...@databricks.com>
Committed: Mon Nov 30 17:22:05 2015 -0800

----------------------------------------------------------------------
 .../mesos/MesosExternalShuffleService.scala     |   3 +-
 .../network/netty/NettyBlockRpcServer.scala     |   8 +-
 .../netty/NettyBlockTransferService.scala       |   6 +-
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    |  16 +--
 .../org/apache/spark/rpc/netty/Outbox.scala     |   9 +-
 .../spark/rpc/netty/NettyRpcHandlerSuite.scala  |   3 +-
 .../network/client/RpcResponseCallback.java     |   4 +-
 .../spark/network/client/TransportClient.java   |  16 ++-
 .../client/TransportResponseHandler.java        |  16 ++-
 .../spark/network/protocol/AbstractMessage.java |  54 +++++++
 .../protocol/AbstractResponseMessage.java       |  32 +++++
 .../network/protocol/ChunkFetchFailure.java     |   2 +-
 .../network/protocol/ChunkFetchRequest.java     |   2 +-
 .../network/protocol/ChunkFetchSuccess.java     |   8 +-
 .../apache/spark/network/protocol/Message.java  |  11 +-
 .../spark/network/protocol/MessageEncoder.java  |  29 ++--
 .../spark/network/protocol/OneWayMessage.java   |  33 +++--
 .../network/protocol/ResponseWithBody.java      |  40 ------
 .../spark/network/protocol/RpcFailure.java      |   2 +-
 .../spark/network/protocol/RpcRequest.java      |  34 +++--
 .../spark/network/protocol/RpcResponse.java     |  39 +++--
 .../spark/network/protocol/StreamFailure.java   |   2 +-
 .../spark/network/protocol/StreamRequest.java   |   2 +-
 .../spark/network/protocol/StreamResponse.java  |  19 +--
 .../spark/network/sasl/SaslClientBootstrap.java |  12 +-
 .../apache/spark/network/sasl/SaslMessage.java  |  31 ++--
 .../spark/network/sasl/SaslRpcHandler.java      |  26 +++-
 .../spark/network/server/MessageHandler.java    |   2 +-
 .../spark/network/server/NoOpRpcHandler.java    |   8 +-
 .../apache/spark/network/server/RpcHandler.java |   8 +-
 .../network/server/TransportChannelHandler.java |   2 +-
 .../network/server/TransportRequestHandler.java |  15 +-
 .../apache/spark/network/util/JavaUtils.java    |  48 ++++---
 .../network/util/TransportFrameDecoder.java     | 142 +++++++++++++------
 .../network/ChunkFetchIntegrationSuite.java     |   5 +-
 .../org/apache/spark/network/ProtocolSuite.java |  10 +-
 .../network/RequestTimeoutIntegrationSuite.java |  51 ++++---
 .../spark/network/RpcIntegrationSuite.java      |  26 ++--
 .../org/apache/spark/network/StreamSuite.java   |   5 +-
 .../network/TransportResponseHandlerSuite.java  |  24 ++--
 .../spark/network/sasl/SparkSaslSuite.java      |  43 ++++--
 .../util/TransportFrameDecoderSuite.java        |  23 ++-
 .../shuffle/ExternalShuffleBlockHandler.java    |   9 +-
 .../network/shuffle/ExternalShuffleClient.java  |   3 +-
 .../network/shuffle/OneForOneBlockFetcher.java  |   7 +-
 .../mesos/MesosExternalShuffleClient.java       |   5 +-
 .../shuffle/protocol/BlockTransferMessage.java  |   8 +-
 .../network/sasl/SaslIntegrationSuite.java      |  18 +--
 .../shuffle/BlockTransferMessagesSuite.java     |   2 +-
 .../ExternalShuffleBlockHandlerSuite.java       |  21 +--
 .../shuffle/OneForOneBlockFetcherSuite.java     |   8 +-
 51 files changed, 617 insertions(+), 335 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
index 12337a9..8ffcfc0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.deploy.mesos
 
 import java.net.SocketAddress
+import java.nio.ByteBuffer
 
 import scala.collection.mutable
 
@@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo
           }
         }
         connectedApps(address) = appId
-        callback.onSuccess(new Array[Byte](0))
+        callback.onSuccess(ByteBuffer.allocate(0))
       case _ => super.handleMessage(message, client, callback)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7696824..df8c21f 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -47,9 +47,9 @@ class NettyBlockRpcServer(
 
   override def receive(
       client: TransportClient,
-      messageBytes: Array[Byte],
+      rpcMessage: ByteBuffer,
       responseContext: RpcResponseCallback): Unit = {
-    val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
+    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
     logTrace(s"Received request: $message")
 
     message match {
@@ -58,7 +58,7 @@ class NettyBlockRpcServer(
           openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
         val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
-        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
 
       case uploadBlock: UploadBlock =>
         // StorageLevel is serialized as bytes using our JavaSerializer.
@@ -66,7 +66,7 @@ class NettyBlockRpcServer(
           serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
         val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
         blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
-        responseContext.onSuccess(new Array[Byte](0))
+        responseContext.onSuccess(ByteBuffer.allocate(0))
     }
   }
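
As a usage note, the renamed codec entry points round-trip like this
(a hedged sketch; OpenBlocks and its constructor are part of the
existing shuffle protocol, not introduced by this commit):

    import java.nio.ByteBuffer;

    import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
    import org.apache.spark.network.shuffle.protocol.OpenBlocks;

    // Encode a protocol message to a ByteBuffer and decode it back.
    final class CodecRoundTrip {
      public static void main(String[] args) {
        BlockTransferMessage msg =
          new OpenBlocks("app-1", "exec-1", new String[] { "shuffle_0_0_0" });
        ByteBuffer wire = msg.toByteBuffer();
        BlockTransferMessage decoded =
          BlockTransferMessage.Decoder.fromByteBuffer(wire);
        System.out.println(msg.equals(decoded));   // true
      }
    }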
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b0694e3..82c16e8 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.netty
 
+import java.nio.ByteBuffer
+
 import scala.collection.JavaConverters._
 import scala.concurrent.{Future, Promise}
 
@@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       data
     }
 
-    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
       new RpcResponseCallback {
-        override def onSuccess(response: Array[Byte]): Unit = {
+        override def onSuccess(response: ByteBuffer): Unit = {
           logTrace(s"Successfully uploaded block $blockId")
           result.success((): Unit)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index c7d74fa..68c5f44 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv(
     promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
   }
 
-  private[netty] def serialize(content: Any): Array[Byte] = {
-    val buffer = javaSerializerInstance.serialize(content)
-    java.util.Arrays.copyOfRange(
-      buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
+  private[netty] def serialize(content: Any): ByteBuffer = {
+    javaSerializerInstance.serialize(content)
   }
 
-  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
+  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
     NettyRpcEnv.currentClient.withValue(client) {
       deserialize { () =>
-        javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+        javaSerializerInstance.deserialize[T](bytes)
       }
     }
   }
@@ -557,7 +555,7 @@ private[netty] class NettyRpcHandler(
 
   override def receive(
       client: TransportClient,
-      message: Array[Byte],
+      message: ByteBuffer,
       callback: RpcResponseCallback): Unit = {
     val messageToDispatch = internalReceive(client, message)
     dispatcher.postRemoteMessage(messageToDispatch, callback)
@@ -565,12 +563,12 @@ private[netty] class NettyRpcHandler(
 
   override def receive(
       client: TransportClient,
-      message: Array[Byte]): Unit = {
+      message: ByteBuffer): Unit = {
     val messageToDispatch = internalReceive(client, message)
     dispatcher.postOneWayMessage(messageToDispatch)
   }
 
-  private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = {
+  private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
     val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 36fdd00..2316ebe 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.rpc.netty
 
+import java.nio.ByteBuffer
 import java.util.concurrent.Callable
 import javax.annotation.concurrent.GuardedBy
 
@@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage {
 
 }
 
-private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage
+private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage
   with Logging {
 
   override def sendWith(client: TransportClient): Unit = {
@@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb
 }
 
 private[netty] case class RpcOutboxMessage(
-    content: Array[Byte],
+    content: ByteBuffer,
     _onFailure: (Throwable) => Unit,
-    _onSuccess: (TransportClient, Array[Byte]) => Unit)
+    _onSuccess: (TransportClient, ByteBuffer) => Unit)
   extends OutboxMessage with RpcResponseCallback {
 
   private var client: TransportClient = _
@@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage(
     _onFailure(e)
   }
 
-  override def onSuccess(response: Array[Byte]): Unit = {
+  override def onSuccess(response: ByteBuffer): Unit = {
     _onSuccess(client, response)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 323184c..ebd6f70 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.rpc.netty
 
 import java.net.InetSocketAddress
+import java.nio.ByteBuffer
 
 import io.netty.channel.Channel
 import org.mockito.Mockito._
@@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
 
   val env = mock(classOf[NettyRpcEnv])
   val sm = mock(classOf[StreamManager])
-  when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+  when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
     .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
 
   test("receive") {

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
index 6ec960d..47e93f9 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -17,13 +17,15 @@
 
 package org.apache.spark.network.client;
 
+import java.nio.ByteBuffer;
+
 /**
  * Callback for the result of a single RPC. This will be invoked once with either success or
  * failure.
  */
 public interface RpcResponseCallback {
   /** Successful serialized result from server. */
-  void onSuccess(byte[] response);
+  void onSuccess(ByteBuffer response);
 
   /** Exception either propagated from server or raised on client side. */
   void onFailure(Throwable e);
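
For illustration, a hypothetical caller-side implementation of the new
interface. The copy inside onSuccess is deliberate: as the
TransportResponseHandler change below shows, the buffer backing the
response is released as soon as onSuccess returns, so the bytes must be
copied out if they need to outlive the callback.

    import java.nio.ByteBuffer;

    import org.apache.spark.network.client.RpcResponseCallback;

    // Hypothetical callback showing the new ByteBuffer-based signature.
    class LoggingRpcCallback implements RpcResponseCallback {
      @Override
      public void onSuccess(ByteBuffer response) {
        // Copy out: the underlying buffer is released after this returns.
        byte[] copy = new byte[response.remaining()];
        response.get(copy);
        System.out.println("RPC succeeded: " + copy.length + " bytes");
      }

      @Override
      public void onFailure(Throwable e) {
        System.err.println("RPC failed: " + e);
      }
    }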

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 8a58e7b..c49ca4d 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -20,6 +20,7 @@ package org.apache.spark.network.client;
 import java.io.Closeable;
 import java.io.IOException;
 import java.net.SocketAddress;
+import java.nio.ByteBuffer;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
@@ -36,6 +37,7 @@ import io.netty.channel.ChannelFutureListener;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.OneWayMessage;
 import org.apache.spark.network.protocol.RpcRequest;
@@ -212,7 +214,7 @@ public class TransportClient implements Closeable {
    * @param callback Callback to handle the RPC's reply.
    * @return The RPC's id.
    */
-  public long sendRpc(byte[] message, final RpcResponseCallback callback) {
+  public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
     final String serverAddr = NettyUtils.getRemoteAddress(channel);
     final long startTime = System.currentTimeMillis();
     logger.trace("Sending RPC to {}", serverAddr);
@@ -220,7 +222,7 @@ public class TransportClient implements Closeable {
     final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
     handler.addRpcRequest(requestId, callback);
 
-    channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+    channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
       new ChannelFutureListener() {
         @Override
         public void operationComplete(ChannelFuture future) throws Exception {
@@ -249,12 +251,12 @@ public class TransportClient implements Closeable {
    * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
    * a specified timeout for a response.
    */
-  public byte[] sendRpcSync(byte[] message, long timeoutMs) {
-    final SettableFuture<byte[]> result = SettableFuture.create();
+  public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+    final SettableFuture<ByteBuffer> result = SettableFuture.create();
 
     sendRpc(message, new RpcResponseCallback() {
       @Override
-      public void onSuccess(byte[] response) {
+      public void onSuccess(ByteBuffer response) {
         result.set(response);
       }
 
@@ -279,8 +281,8 @@ public class TransportClient implements Closeable {
    *
    * @param message The message to send.
    */
-  public void send(byte[] message) {
-    channel.writeAndFlush(new OneWayMessage(message));
+  public void send(ByteBuffer message) {
+    channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
   }
 
   /**
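
A hedged usage sketch of the new signatures (how the client itself is
built is out of scope here, and the "ping" payload is made up):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    import org.apache.spark.network.client.TransportClient;

    // Send an opaque ByteBuffer payload and block for the reply.
    final class PingSketch {
      static String ping(TransportClient client) {
        ByteBuffer request =
          ByteBuffer.wrap("ping".getBytes(StandardCharsets.UTF_8));
        ByteBuffer reply = client.sendRpcSync(request, 5000 /* timeout ms */);
        return StandardCharsets.UTF_8.decode(reply).toString();
      }
    }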

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 4c15045..23a8dba 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -136,7 +136,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
   }
 
   @Override
-  public void handle(ResponseMessage message) {
+  public void handle(ResponseMessage message) throws Exception {
     String remoteAddress = NettyUtils.getRemoteAddress(channel);
     if (message instanceof ChunkFetchSuccess) {
       ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
@@ -144,11 +144,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
       if (listener == null) {
         logger.warn("Ignoring response for block {} from {} since it is not outstanding",
           resp.streamChunkId, remoteAddress);
-        resp.body.release();
+        resp.body().release();
       } else {
         outstandingFetches.remove(resp.streamChunkId);
-        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
-        resp.body.release();
+        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+        resp.body().release();
       }
     } else if (message instanceof ChunkFetchFailure) {
       ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -166,10 +166,14 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
       RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
-          resp.requestId, remoteAddress, resp.response.length);
+          resp.requestId, remoteAddress, resp.body().size());
       } else {
         outstandingRpcs.remove(resp.requestId);
-        listener.onSuccess(resp.response);
+        try {
+          listener.onSuccess(resp.body().nioByteBuffer());
+        } finally {
+          resp.body().release();
+        }
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
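
The try/finally above boils down to a simple discipline, sketched
standalone here: hand the listener an NIO view of the body, then
release the reference even if the listener throws.

    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.client.RpcResponseCallback;

    // Deliver an RPC response body and always release it afterwards.
    final class DeliverSketch {
      static void deliver(ManagedBuffer body, RpcResponseCallback listener)
          throws Exception {
        try {
          listener.onSuccess(body.nioByteBuffer());
        } finally {
          body.release();
        }
      }
    }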

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
new file mode 100644
index 0000000..2924218
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractMessage.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for messages which optionally contain a body kept in a separate buffer.
+ */
+public abstract class AbstractMessage implements Message {
+  private final ManagedBuffer body;
+  private final boolean isBodyInFrame;
+
+  protected AbstractMessage() {
+    this(null, false);
+  }
+
+  protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
+    this.body = body;
+    this.isBodyInFrame = isBodyInFrame;
+  }
+
+  @Override
+  public ManagedBuffer body() {
+    return body;
+  }
+
+  @Override
+  public boolean isBodyInFrame() {
+    return isBodyInFrame;
+  }
+
+  protected boolean equals(AbstractMessage other) {
+    return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
new file mode 100644
index 0000000..c362c92
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/AbstractResponseMessage.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for response messages.
+ */
+public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage {
+
+  protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) {
+    super(body, isBodyInFrame);
+  }
+
+  public abstract ResponseMessage createFailureResponse(String error);
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
index f036383..7b28a9a 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -23,7 +23,7 @@ import io.netty.buffer.ByteBuf;
 /**
  * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
  */
-public final class ChunkFetchFailure implements ResponseMessage {
+public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage {
   public final StreamChunkId streamChunkId;
   public final String errorString;
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index 5a173af..26d063f 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -24,7 +24,7 @@ import io.netty.buffer.ByteBuf;
  * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single
  * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
  */
-public final class ChunkFetchRequest implements RequestMessage {
+public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
   public final StreamChunkId streamChunkId;
 
   public ChunkFetchRequest(StreamChunkId streamChunkId) {

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index e6a7e9a..94c2ac9 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -30,7 +30,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
  * may be written by Netty in a more efficient manner (i.e., zero-copy write).
  * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
  */
-public final class ChunkFetchSuccess extends ResponseWithBody {
+public final class ChunkFetchSuccess extends AbstractResponseMessage {
   public final StreamChunkId streamChunkId;
 
   public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
@@ -67,14 +67,14 @@ public final class ChunkFetchSuccess extends ResponseWithBody {
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(streamChunkId, body);
+    return Objects.hashCode(streamChunkId, body());
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof ChunkFetchSuccess) {
       ChunkFetchSuccess o = (ChunkFetchSuccess) other;
-      return streamChunkId.equals(o.streamChunkId) && body.equals(o.body);
+      return streamChunkId.equals(o.streamChunkId) && super.equals(o);
     }
     return false;
   }
@@ -83,7 +83,7 @@ public final class ChunkFetchSuccess extends ResponseWithBody {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("streamChunkId", streamChunkId)
-      .add("buffer", body)
+      .add("buffer", body())
       .toString();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index 39afd03..66f5b8b 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -19,17 +19,25 @@ package org.apache.spark.network.protocol;
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.spark.network.buffer.ManagedBuffer;
+
 /** An on-the-wire transmittable message. */
 public interface Message extends Encodable {
   /** Used to identify this request type. */
   Type type();
 
+  /** An optional body for the message. */
+  ManagedBuffer body();
+
+  /** Whether to include the body of the message in the same frame as the message. */
+  boolean isBodyInFrame();
+
   /** Preceding every serialized Message is its type, which allows us to deserialize it. */
   public static enum Type implements Encodable {
     ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
     RpcRequest(3), RpcResponse(4), RpcFailure(5),
     StreamRequest(6), StreamResponse(7), StreamFailure(8),
-    OneWayMessage(9);
+    OneWayMessage(9), User(-1);
 
     private final byte id;
 
@@ -57,6 +65,7 @@ public interface Message extends Encodable {
         case 7: return StreamResponse;
         case 8: return StreamFailure;
         case 9: return OneWayMessage;
+        case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
         default: throw new IllegalArgumentException("Unknown message type: " + id);
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 6cce97c..abca223 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -42,25 +42,28 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
    * data to 'out', in order to enable zero-copy transfer.
    */
   @Override
-  public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
+  public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
     Object body = null;
     long bodyLength = 0;
     boolean isBodyInFrame = false;
 
-    // Detect ResponseWithBody messages and get the data buffer out of them.
-    // The body is used in order to enable zero-copy transfer for the payload.
-    if (in instanceof ResponseWithBody) {
-      ResponseWithBody resp = (ResponseWithBody) in;
+    // If the message has a body, take it out to enable zero-copy transfer for the payload.
+    if (in.body() != null) {
       try {
-        bodyLength = resp.body.size();
-        body = resp.body.convertToNetty();
-        isBodyInFrame = resp.isBodyInFrame;
+        bodyLength = in.body().size();
+        body = in.body().convertToNetty();
+        isBodyInFrame = in.isBodyInFrame();
       } catch (Exception e) {
-        // Re-encode this message as a failure response.
-        String error = e.getMessage() != null ? e.getMessage() : "null";
-        logger.error(String.format("Error processing %s for client %s",
-          resp, ctx.channel().remoteAddress()), e);
-        encode(ctx, resp.createFailureResponse(error), out);
+        if (in instanceof AbstractResponseMessage) {
+          AbstractResponseMessage resp = (AbstractResponseMessage) in;
+          // Re-encode this message as a failure response.
+          String error = e.getMessage() != null ? e.getMessage() : "null";
+          logger.error(String.format("Error processing %s for client %s",
+            in, ctx.channel().remoteAddress()), e);
+          encode(ctx, resp.createFailureResponse(error), out);
+        } else {
+          throw e;
+        }
         return;
       }
     }
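
For context, a sketch of the frame the encoder ultimately emits. The
remainder of encode() falls outside this hunk, so the layout below is
an assumption inferred from the surrounding code, not the commit's own
code:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    // Assumed frame layout:
    //   [8-byte frame length][1-byte message type][header fields][body?]
    // The frame length counts itself; the body is part of the frame only
    // when isBodyInFrame() is true (StreamResponse sends its body
    // separately, outside the frame).
    final class FrameHeaderSketch {
      static ByteBuf header(byte typeId, int msgEncodedLength, long bodyLength,
          boolean isBodyInFrame) {
        int headerLength = 8 + 1 + msgEncodedLength;
        long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
        ByteBuf header = Unpooled.buffer(headerLength);
        header.writeLong(frameLength);
        header.writeByte(typeId);
        // message.encode(header) would append the remaining header fields.
        return header;
      }
    }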

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
index 95a0270..efe0470 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -17,21 +17,21 @@
 
 package org.apache.spark.network.protocol;
 
-import java.util.Arrays;
-
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
 
 /**
  * A RPC that does not expect a reply, which is handled by a remote
  * {@link org.apache.spark.network.server.RpcHandler}.
  */
-public final class OneWayMessage implements RequestMessage {
-  /** Serialized message to send to remote RpcHandler. */
-  public final byte[] message;
+public final class OneWayMessage extends AbstractMessage implements RequestMessage {
 
-  public OneWayMessage(byte[] message) {
-    this.message = message;
+  public OneWayMessage(ManagedBuffer body) {
+    super(body, true);
   }
 
   @Override
@@ -39,29 +39,34 @@ public final class OneWayMessage implements RequestMessage {
 
   @Override
   public int encodedLength() {
-    return Encoders.ByteArrays.encodedLength(message);
+    // The integer (a.k.a. the body size) is not really used, since that information is already
+    // encoded in the frame length. But this maintains backwards compatibility with versions of
+    // RpcRequest that use Encoders.ByteArrays.
+    return 4;
   }
 
   @Override
   public void encode(ByteBuf buf) {
-    Encoders.ByteArrays.encode(buf, message);
+    // See comment in encodedLength().
+    buf.writeInt((int) body().size());
   }
 
   public static OneWayMessage decode(ByteBuf buf) {
-    byte[] message = Encoders.ByteArrays.decode(buf);
-    return new OneWayMessage(message);
+    // See comment in encodedLength().
+    buf.readInt();
+    return new OneWayMessage(new NettyManagedBuffer(buf.retain()));
   }
 
   @Override
   public int hashCode() {
-    return Arrays.hashCode(message);
+    return Objects.hashCode(body());
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof OneWayMessage) {
       OneWayMessage o = (OneWayMessage) other;
-      return Arrays.equals(message, o.message);
+      return super.equals(o);
     }
     return false;
   }
@@ -69,7 +74,7 @@ public final class OneWayMessage implements RequestMessage {
   @Override
   public String toString() {
     return Objects.toStringHelper(this)
-      .add("message", message)
+      .add("body", body())
       .toString();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
deleted file mode 100644
index 67be77e..0000000
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.protocol;
-
-import com.google.common.base.Objects;
-import io.netty.buffer.ByteBuf;
-
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NettyManagedBuffer;
-
-/**
- * Abstract class for response messages that contain a large data portion kept in a separate
- * buffer. These messages are treated especially by MessageEncoder.
- */
-public abstract class ResponseWithBody implements ResponseMessage {
-  public final ManagedBuffer body;
-  public final boolean isBodyInFrame;
-
-  protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
-    this.body = body;
-    this.isBodyInFrame = isBodyInFrame;
-  }
-
-  public abstract ResponseMessage createFailureResponse(String error);
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
index 2dfc787..a76624e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java
@@ -21,7 +21,7 @@ import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
 /** Response to {@link RpcRequest} for a failed RPC. */
-public final class RpcFailure implements ResponseMessage {
+public final class RpcFailure extends AbstractMessage implements ResponseMessage {
   public final long requestId;
   public final String errorString;
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
index 745039d..9621379 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java
@@ -17,26 +17,25 @@
 
 package org.apache.spark.network.protocol;
 
-import java.util.Arrays;
-
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
 
 /**
  * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}.
  * This will correspond to a single
  * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
  */
-public final class RpcRequest implements RequestMessage {
+public final class RpcRequest extends AbstractMessage implements RequestMessage {
   /** Used to link an RPC request with its response. */
   public final long requestId;
 
-  /** Serialized message to send to remote RpcHandler. */
-  public final byte[] message;
-
-  public RpcRequest(long requestId, byte[] message) {
+  public RpcRequest(long requestId, ManagedBuffer message) {
+    super(message, true);
     this.requestId = requestId;
-    this.message = message;
   }
 
   @Override
@@ -44,31 +43,36 @@ public final class RpcRequest implements RequestMessage {
 
   @Override
   public int encodedLength() {
-    return 8 + Encoders.ByteArrays.encodedLength(message);
+    // The integer (a.k.a. the body size) is not really used, since that information is already
+    // encoded in the frame length. But this maintains backwards compatibility with versions of
+    // RpcRequest that use Encoders.ByteArrays.
+    return 8 + 4;
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    Encoders.ByteArrays.encode(buf, message);
+    // See comment in encodedLength().
+    buf.writeInt((int) body().size());
   }
 
   public static RpcRequest decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    byte[] message = Encoders.ByteArrays.decode(buf);
-    return new RpcRequest(requestId, message);
+    // See comment in encodedLength().
+    buf.readInt();
+    return new RpcRequest(requestId, new NettyManagedBuffer(buf.retain()));
   }
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(requestId, Arrays.hashCode(message));
+    return Objects.hashCode(requestId, body());
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof RpcRequest) {
       RpcRequest o = (RpcRequest) other;
-      return requestId == o.requestId && Arrays.equals(message, o.message);
+      return requestId == o.requestId && super.equals(o);
     }
     return false;
   }
@@ -77,7 +81,7 @@ public final class RpcRequest implements RequestMessage {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("requestId", requestId)
-      .add("message", message)
+      .add("body", body())
       .toString();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
index 1671cd4..bae866e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java
@@ -17,49 +17,62 @@
 
 package org.apache.spark.network.protocol;
 
-import java.util.Arrays;
-
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
 
 /** Response to {@link RpcRequest} for a successful RPC. */
-public final class RpcResponse implements ResponseMessage {
+public final class RpcResponse extends AbstractResponseMessage {
   public final long requestId;
-  public final byte[] response;
 
-  public RpcResponse(long requestId, byte[] response) {
+  public RpcResponse(long requestId, ManagedBuffer message) {
+    super(message, true);
     this.requestId = requestId;
-    this.response = response;
   }
 
   @Override
   public Type type() { return Type.RpcResponse; }
 
   @Override
-  public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); }
+  public int encodedLength() {
+    // The integer (a.k.a. the body size) is not really used, since that information is already
+    // encoded in the frame length. But this maintains backwards compatibility with versions of
+    // RpcRequest that use Encoders.ByteArrays.
+    return 8 + 4;
+  }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeLong(requestId);
-    Encoders.ByteArrays.encode(buf, response);
+    // See comment in encodedLength().
+    buf.writeInt((int) body().size());
+  }
+
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new RpcFailure(requestId, error);
   }
 
   public static RpcResponse decode(ByteBuf buf) {
     long requestId = buf.readLong();
-    byte[] response = Encoders.ByteArrays.decode(buf);
-    return new RpcResponse(requestId, response);
+    // See comment in encodedLength().
+    buf.readInt();
+    return new RpcResponse(requestId, new NettyManagedBuffer(buf.retain()));
   }
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(requestId, Arrays.hashCode(response));
+    return Objects.hashCode(requestId, body());
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof RpcResponse) {
       RpcResponse o = (RpcResponse) other;
-      return requestId == o.requestId && Arrays.equals(response, o.response);
+      return requestId == o.requestId && super.equals(o);
     }
     return false;
   }
@@ -68,7 +81,7 @@ public final class RpcResponse implements ResponseMessage {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("requestId", requestId)
-      .add("response", response)
+      .add("body", body())
       .toString();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
index e3dade2..26747ee 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -26,7 +26,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
 /**
  * Message indicating an error when transferring a stream.
  */
-public final class StreamFailure implements ResponseMessage {
+public final class StreamFailure extends AbstractMessage implements ResponseMessage {
   public final String streamId;
   public final String error;
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
index 821e8f5..35af5a8 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -29,7 +29,7 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
  * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints before
  * the data can be streamed.
  */
-public final class StreamRequest implements RequestMessage {
+public final class StreamRequest extends AbstractMessage implements RequestMessage {
    public final String streamId;
 
    public StreamRequest(String streamId) {

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
index ac5ab9a..51b8999 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -30,15 +30,15 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
  * sender. The receiver is expected to set a temporary channel handler that will consume the
  * number of bytes this message says the stream has.
  */
-public final class StreamResponse extends ResponseWithBody {
-   public final String streamId;
-   public final long byteCount;
+public final class StreamResponse extends AbstractResponseMessage {
+  public final String streamId;
+  public final long byteCount;
 
-   public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
-     super(buffer, false);
-     this.streamId = streamId;
-     this.byteCount = byteCount;
-   }
+  public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+    super(buffer, false);
+    this.streamId = streamId;
+    this.byteCount = byteCount;
+  }
 
   @Override
   public Type type() { return Type.StreamResponse; }
@@ -68,7 +68,7 @@ public final class StreamResponse extends ResponseWithBody {
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(byteCount, streamId);
+    return Objects.hashCode(byteCount, streamId, body());
   }
 
   @Override
@@ -85,6 +85,7 @@ public final class StreamResponse extends ResponseWithBody {
     return Objects.toStringHelper(this)
       .add("streamId", streamId)
       .add("byteCount", byteCount)
+      .add("body", body())
       .toString();
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
index 6992376..6838103 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.sasl;
 
+import java.io.IOException;
+import java.nio.ByteBuffer;
 import javax.security.sasl.Sasl;
 import javax.security.sasl.SaslException;
 
@@ -28,6 +30,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.TransportConf;
 
 /**
@@ -70,11 +73,12 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
 
       while (!saslClient.isComplete()) {
         SaslMessage msg = new SaslMessage(appId, payload);
-        ByteBuf buf = Unpooled.buffer(msg.encodedLength());
+        ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
         msg.encode(buf);
+        buf.writeBytes(msg.body().nioByteBuffer());
 
-        byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
-        payload = saslClient.response(response);
+        ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
+        payload = saslClient.response(JavaUtils.bufferToArray(response));
       }
 
       client.setClientId(appId);
@@ -88,6 +92,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
         saslClient = null;
         logger.debug("Channel {} configured for SASL encryption.", client);
       }
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
     } finally {
       if (saslClient != null) {
         try {
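
JavaUtils.bufferToArray is used here, but its diff falls outside this
excerpt; a minimal sketch of what such a helper needs to do (an
assumption, not the actual JavaUtils code): return the backing array
when it lines up exactly, otherwise copy.

    import java.nio.ByteBuffer;

    // ByteBuffer -> byte[] for APIs (like JDK SASL) that require arrays.
    final class BufferToArraySketch {
      static byte[] bufferToArray(ByteBuffer buffer) {
        if (buffer.hasArray() && buffer.arrayOffset() == 0
            && buffer.array().length == buffer.remaining()) {
          return buffer.array();            // exact backing array: no copy
        }
        byte[] bytes = new byte[buffer.remaining()];
        buffer.get(bytes);                  // direct or sliced buffer: one copy
        return bytes;
      }
    }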

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
index cad76ab..e52b526 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslMessage.java
@@ -18,38 +18,50 @@
 package org.apache.spark.network.sasl;
 
 import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
 
-import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
 import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.protocol.AbstractMessage;
 
 /**
  * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged
  * with the given appId. This appId allows a single SaslRpcHandler to multiplex different
  * applications which may be using different sets of credentials.
  */
-class SaslMessage implements Encodable {
+class SaslMessage extends AbstractMessage {
 
   /** Serialization tag used to catch incorrect payloads. */
   private static final byte TAG_BYTE = (byte) 0xEA;
 
   public final String appId;
-  public final byte[] payload;
 
-  public SaslMessage(String appId, byte[] payload) {
+  public SaslMessage(String appId, byte[] message) {
+    this(appId, Unpooled.wrappedBuffer(message));
+  }
+
+  public SaslMessage(String appId, ByteBuf message) {
+    super(new NettyManagedBuffer(message), true);
     this.appId = appId;
-    this.payload = payload;
   }
 
   @Override
+  public Type type() { return Type.User; }
+
+  @Override
   public int encodedLength() {
-    return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload);
+    // The integer (a.k.a. the body size) is not really used, since that information is already
+    // encoded in the frame length. But this maintains backwards compatibility with versions of
+    // RpcRequest that use Encoders.ByteArrays.
+    return 1 + Encoders.Strings.encodedLength(appId) + 4;
   }
 
   @Override
   public void encode(ByteBuf buf) {
     buf.writeByte(TAG_BYTE);
     Encoders.Strings.encode(buf, appId);
-    Encoders.ByteArrays.encode(buf, payload);
+    // See comment in encodedLength().
+    buf.writeInt((int) body().size());
   }
 
   public static SaslMessage decode(ByteBuf buf) {
@@ -59,7 +71,8 @@ class SaslMessage implements Encodable {
     }
 
     String appId = Encoders.Strings.decode(buf);
-    byte[] payload = Encoders.ByteArrays.decode(buf);
-    return new SaslMessage(appId, payload);
+    // See comment in encodedLength().
+    buf.readInt();
+    return new SaslMessage(appId, buf.retain());
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 830db94..c215bd9 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -17,8 +17,11 @@
 
 package org.apache.spark.network.sasl;
 
+import java.io.IOException;
+import java.nio.ByteBuffer;
 import javax.security.sasl.Sasl;
 
+import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.Channel;
 import org.slf4j.Logger;
@@ -28,6 +31,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.TransportConf;
 
 /**
@@ -70,14 +74,20 @@ class SaslRpcHandler extends RpcHandler {
   }
 
   @Override
-  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
     if (isComplete) {
       // Authentication complete, delegate to base handler.
       delegate.receive(client, message, callback);
       return;
     }
 
-    SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message));
+    ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
+    SaslMessage saslMessage;
+    try {
+      saslMessage = SaslMessage.decode(nettyBuf);
+    } finally {
+      nettyBuf.release();
+    }
 
     if (saslServer == null) {
       // First message in the handshake, setup the necessary state.
@@ -86,8 +96,14 @@ class SaslRpcHandler extends RpcHandler {
         conf.saslServerAlwaysEncrypt());
     }
 
-    byte[] response = saslServer.response(saslMessage.payload);
-    callback.onSuccess(response);
+    byte[] response;
+    try {
+      response = saslServer.response(JavaUtils.bufferToArray(
+        saslMessage.body().nioByteBuffer()));
+    } catch (IOException ioe) {
+      throw new RuntimeException(ioe);
+    }
+    callback.onSuccess(ByteBuffer.wrap(response));
 
     // Setup encryption after the SASL response is sent, otherwise the client can't parse the
     // response. It's ok to change the channel pipeline here since we are processing an incoming
@@ -109,7 +125,7 @@ class SaslRpcHandler extends RpcHandler {
   }
 
   @Override
-  public void receive(TransportClient client, byte[] message) {
+  public void receive(TransportClient client, ByteBuffer message) {
     delegate.receive(client, message);
   }
 

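The decode/release pairing above is safe because SaslMessage.decode() retains the buffer before the handler releases its own reference. A small, self-contained illustration of that reference-counting contract (not part of the patch):

  import java.nio.ByteBuffer;

  import io.netty.buffer.ByteBuf;
  import io.netty.buffer.Unpooled;

  final class RefCountSketch {
    public static void main(String[] args) {
      ByteBuf buf = Unpooled.wrappedBuffer(ByteBuffer.wrap(new byte[] { 1, 2, 3 }));
      buf.retain();    // what SaslMessage.decode() does before returning
      buf.release();   // what the handler's finally block does
      System.out.println(buf.refCnt());  // 1: the message body is still alive
    }
  }
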
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
index b80c151..3843406 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -26,7 +26,7 @@ import org.apache.spark.network.protocol.Message;
  */
 public abstract class MessageHandler<T extends Message> {
   /** Handles the receipt of a single message. */
-  public abstract void handle(T message);
+  public abstract void handle(T message) throws Exception;
 
   /** Invoked when an exception was caught on the Channel. */
   public abstract void exceptionCaught(Throwable cause);

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
index 1502b74..6ed61da 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java
@@ -1,5 +1,3 @@
-package org.apache.spark.network.server;
-
 /*
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
@@ -17,6 +15,10 @@ package org.apache.spark.network.server;
  * limitations under the License.
  */
 
+package org.apache.spark.network.server;
+
+import java.nio.ByteBuffer;
+
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 
@@ -29,7 +31,7 @@ public class NoOpRpcHandler extends RpcHandler {
   }
 
   @Override
-  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
     throw new UnsupportedOperationException("Cannot handle messages");
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index 1a11f7b..ee1c683 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.server;
 
+import java.nio.ByteBuffer;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -44,7 +46,7 @@ public abstract class RpcHandler {
    */
   public abstract void receive(
       TransportClient client,
-      byte[] message,
+      ByteBuffer message,
       RpcResponseCallback callback);
 
   /**
@@ -62,7 +64,7 @@ public abstract class RpcHandler {
    *               of this RPC. This will always be the exact same object for a particular channel.
    * @param message The serialized bytes of the RPC.
    */
-  public void receive(TransportClient client, byte[] message) {
+  public void receive(TransportClient client, ByteBuffer message) {
     receive(client, message, ONE_WAY_CALLBACK);
   }
 
@@ -79,7 +81,7 @@ public abstract class RpcHandler {
     private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
 
     @Override
-    public void onSuccess(byte[] response) {
+    public void onSuccess(ByteBuffer response) {
       logger.warn("Response provided for one-way RPC.");
     }
 

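Pulling the new contract together, a minimal echo handler against the ByteBuffer signature might look like the sketch below (EchoRpcHandler is illustrative; OneForOneStreamManager is simply an existing StreamManager implementation):

  import java.nio.ByteBuffer;

  import org.apache.spark.network.client.RpcResponseCallback;
  import org.apache.spark.network.client.TransportClient;
  import org.apache.spark.network.server.OneForOneStreamManager;
  import org.apache.spark.network.server.RpcHandler;
  import org.apache.spark.network.server.StreamManager;

  class EchoRpcHandler extends RpcHandler {
    private final StreamManager streamManager = new OneForOneStreamManager();

    @Override
    public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
      // The request body is released by TransportRequestHandler after this
      // method returns, so copy anything that needs to outlive the call.
      ByteBuffer reply = ByteBuffer.allocate(message.remaining());
      reply.put(message);
      reply.flip();
      callback.onSuccess(reply);
    }

    @Override
    public StreamManager getStreamManager() {
      return streamManager;
    }
  }
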
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 3164e00..09435bc 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -99,7 +99,7 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
   }
 
   @Override
-  public void channelRead0(ChannelHandlerContext ctx, Message request) {
+  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
     if (request instanceof RequestMessage) {
       requestHandler.handle((RequestMessage) request);
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index db18ea7..c864d7c 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.server;
 
+import java.nio.ByteBuffer;
+
 import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import io.netty.channel.Channel;
@@ -26,6 +28,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.protocol.ChunkFetchRequest;
@@ -143,10 +146,10 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
 
   private void processRpcRequest(final RpcRequest req) {
     try {
-      rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
+      rpcHandler.receive(reverseClient, req.body().nioByteBuffer(), new RpcResponseCallback() {
         @Override
-        public void onSuccess(byte[] response) {
-          respond(new RpcResponse(req.requestId, response));
+        public void onSuccess(ByteBuffer response) {
+          respond(new RpcResponse(req.requestId, new NioManagedBuffer(response)));
         }
 
         @Override
@@ -157,14 +160,18 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     } catch (Exception e) {
       logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
       respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+    } finally {
+      req.body().release();
     }
   }
 
   private void processOneWayMessage(OneWayMessage req) {
     try {
-      rpcHandler.receive(reverseClient, req.message);
+      rpcHandler.receive(reverseClient, req.body().nioByteBuffer());
     } catch (Exception e) {
       logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+    } finally {
+      req.body().release();
     }
   }
 

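The one-way path above has a matching fire-and-forget call on the client side. A hedged sketch, assuming TransportClient.send(ByteBuffer) as used elsewhere in this change; no response is expected, and the server releases the body after the handler returns:

  import java.nio.ByteBuffer;
  import java.nio.charset.StandardCharsets;

  import org.apache.spark.network.client.TransportClient;

  final class OneWaySketch {
    /** Fires a message with no reply; errors are only logged on the server. */
    static void sendOneWay(TransportClient client, String msg) {
      client.send(ByteBuffer.wrap(msg.getBytes(StandardCharsets.UTF_8)));
    }
  }
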
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
index 7d27439..b3d8e0c 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java
@@ -132,7 +132,7 @@ public class JavaUtils {
     return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile());
   }
 
-  private static final ImmutableMap<String, TimeUnit> timeSuffixes = 
+  private static final ImmutableMap<String, TimeUnit> timeSuffixes =
     ImmutableMap.<String, TimeUnit>builder()
       .put("us", TimeUnit.MICROSECONDS)
       .put("ms", TimeUnit.MILLISECONDS)
@@ -164,32 +164,32 @@ public class JavaUtils {
    */
   private static long parseTimeString(String str, TimeUnit unit) {
     String lower = str.toLowerCase().trim();
-    
+
     try {
       Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower);
       if (!m.matches()) {
         throw new NumberFormatException("Failed to parse time string: " + str);
       }
-      
+
       long val = Long.parseLong(m.group(1));
       String suffix = m.group(2);
-      
+
       // Check for invalid suffixes
       if (suffix != null && !timeSuffixes.containsKey(suffix)) {
         throw new NumberFormatException("Invalid suffix: \"" + suffix + "\"");
       }
-      
+
       // If suffix is valid use that, otherwise none was provided and use the default passed
       return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit);
     } catch (NumberFormatException e) {
       String timeError = "Time must be specified as seconds (s), " +
               "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " +
               "E.g. 50s, 100ms, or 250us.";
-      
+
       throw new NumberFormatException(timeError + "\n" + e.getMessage());
     }
   }
-  
+
   /**
    * Convert a time parameter such as (50s, 100ms, or 250us) to milliseconds for internal use. If
    * no suffix is provided, the passed number is assumed to be in ms.
@@ -205,10 +205,10 @@ public class JavaUtils {
   public static long timeStringAsSec(String str) {
     return parseTimeString(str, TimeUnit.SECONDS);
   }
-  
+
   /**
    * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for
-   * internal use. If no suffix is provided a direct conversion of the provided default is 
+   * internal use. If no suffix is provided a direct conversion of the provided default is
    * attempted.
    */
   private static long parseByteString(String str, ByteUnit unit) {
@@ -217,7 +217,7 @@ public class JavaUtils {
     try {
       Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower);
       Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower);
-      
+
       if (m.matches()) {
         long val = Long.parseLong(m.group(1));
         String suffix = m.group(2);
@@ -228,14 +228,14 @@ public class JavaUtils {
         }
 
         // If suffix is valid use that, otherwise none was provided and use the default passed
-        return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);  
+        return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit);
       } else if (fractionMatcher.matches()) {
-        throw new NumberFormatException("Fractional values are not supported. Input was: " 
+        throw new NumberFormatException("Fractional values are not supported. Input was: "
           + fractionMatcher.group(1));
       } else {
-        throw new NumberFormatException("Failed to parse byte string: " + str);  
+        throw new NumberFormatException("Failed to parse byte string: " + str);
       }
-      
+
     } catch (NumberFormatException e) {
       String timeError = "Size must be specified as bytes (b), " +
         "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " +
@@ -248,7 +248,7 @@ public class JavaUtils {
   /**
    * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for
    * internal use.
-   * 
+   *
    * If no suffix is provided, the passed number is assumed to be in bytes.
    */
   public static long byteStringAsBytes(String str) {
@@ -264,7 +264,7 @@ public class JavaUtils {
   public static long byteStringAsKb(String str) {
     return parseByteString(str, ByteUnit.KiB);
   }
-  
+
   /**
    * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for
    * internal use.
@@ -284,4 +284,20 @@ public class JavaUtils {
   public static long byteStringAsGb(String str) {
     return parseByteString(str, ByteUnit.GiB);
   }
+
+  /**
+   * Returns a byte array with the buffer's contents, trying to avoid copying the data if
+   * possible.
+   */
+  public static byte[] bufferToArray(ByteBuffer buffer) {
+    if (buffer.hasArray() && buffer.arrayOffset() == 0 &&
+        buffer.array().length == buffer.remaining()) {
+      return buffer.array();
+    } else {
+      byte[] bytes = new byte[buffer.remaining()];
+      buffer.get(bytes);
+      return bytes;
+    }
+  }
+
 }

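A quick illustration of the two paths through bufferToArray(): a heap buffer whose backing array exactly matches the remaining bytes comes back without a copy, while a direct buffer forces one (sketch only, not from the patch):

  import java.nio.ByteBuffer;

  import org.apache.spark.network.util.JavaUtils;

  final class BufferToArraySketch {
    public static void main(String[] args) {
      ByteBuffer heap = ByteBuffer.wrap(new byte[] { 1, 2, 3 });
      System.out.println(JavaUtils.bufferToArray(heap) == heap.array());  // true: no copy

      ByteBuffer direct = ByteBuffer.allocateDirect(3);
      direct.put(new byte[] { 1, 2, 3 });
      direct.flip();
      System.out.println(JavaUtils.bufferToArray(direct).length);  // 3: copied path
    }
  }
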
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 5889562..a466c72 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -17,9 +17,13 @@
 
 package org.apache.spark.network.util;
 
+import java.util.Iterator;
+import java.util.LinkedList;
+
 import com.google.common.base.Preconditions;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 
@@ -44,84 +48,138 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
   public static final String HANDLER_NAME = "frameDecoder";
   private static final int LENGTH_SIZE = 8;
   private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+  private static final int UNKNOWN_FRAME_SIZE = -1;
+
+  private final LinkedList<ByteBuf> buffers = new LinkedList<>();
+  private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE);
 
-  private CompositeByteBuf buffer;
+  private long totalSize = 0;
+  private long nextFrameSize = UNKNOWN_FRAME_SIZE;
   private volatile Interceptor interceptor;
 
   @Override
   public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
     ByteBuf in = (ByteBuf) data;
+    buffers.add(in);
+    totalSize += in.readableBytes();
+
+    while (!buffers.isEmpty()) {
+      // First, feed the interceptor, and if it's still active, try again.
+      if (interceptor != null) {
+        ByteBuf first = buffers.getFirst();
+        int available = first.readableBytes();
+        if (feedInterceptor(first)) {
+          assert !first.isReadable() : "Interceptor still active but buffer has data.";
+        }
 
-    if (buffer == null) {
-      buffer = in.alloc().compositeBuffer();
-    }
-
-    buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes());
-
-    while (buffer.isReadable()) {
-      discardReadBytes();
-      if (!feedInterceptor()) {
+        int read = available - first.readableBytes();
+        if (read == available) {
+          buffers.removeFirst().release();
+        }
+        totalSize -= read;
+      } else {
+        // Interceptor is not active, so try to decode one frame.
         ByteBuf frame = decodeNext();
         if (frame == null) {
           break;
         }
-
         ctx.fireChannelRead(frame);
       }
     }
-
-    discardReadBytes();
   }
 
-  private void discardReadBytes() {
-    // If the buffer's been retained by downstream code, then make a copy of the remaining
-    // bytes into a new buffer. Otherwise, just discard stale components.
-    if (buffer.refCnt() > 1) {
-      CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer();
+  private long decodeFrameSize() {
+    if (nextFrameSize != UNKNOWN_FRAME_SIZE || totalSize < LENGTH_SIZE) {
+      return nextFrameSize;
+    }
 
-      if (buffer.readableBytes() > 0) {
-        ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes());
-        spillBuf.writeBytes(buffer);
-        newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes());
+    // We know there's enough data. If the first buffer contains all the data, great. Otherwise,
+    // accumulate the bytes of the frame length in a small fixed-size buffer until there are
+    // enough to read the frame size. It should normally be rare to need more than one incoming
+    // buffer to read the frame size.
+    ByteBuf first = buffers.getFirst();
+    if (first.readableBytes() >= LENGTH_SIZE) {
+      nextFrameSize = first.readLong() - LENGTH_SIZE;
+      totalSize -= LENGTH_SIZE;
+      if (!first.isReadable()) {
+        buffers.removeFirst().release();
       }
+      return nextFrameSize;
+    }
 
-      buffer.release();
-      buffer = newBuffer;
-    } else {
-      buffer.discardReadComponents();
+    while (frameLenBuf.readableBytes() < LENGTH_SIZE) {
+      ByteBuf next = buffers.getFirst();
+      int toRead = Math.min(next.readableBytes(), LENGTH_SIZE - frameLenBuf.readableBytes());
+      frameLenBuf.writeBytes(next, toRead);
+      if (!next.isReadable()) {
+        buffers.removeFirst().release();
+      }
     }
+
+    nextFrameSize = frameLenBuf.readLong() - LENGTH_SIZE;
+    totalSize -= LENGTH_SIZE;
+    frameLenBuf.clear();
+    return nextFrameSize;
   }
 
   private ByteBuf decodeNext() throws Exception {
-    if (buffer.readableBytes() < LENGTH_SIZE) {
+    long frameSize = decodeFrameSize();
+    if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
       return null;
     }
 
-    int frameLen = (int) buffer.readLong() - LENGTH_SIZE;
-    if (buffer.readableBytes() < frameLen) {
-      buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE);
-      return null;
+    // Reset size for next frame.
+    nextFrameSize = UNKNOWN_FRAME_SIZE;
+
+    Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize);
+    Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize);
+
+    // If the first buffer holds the entire frame, return it.
+    int remaining = (int) frameSize;
+    if (buffers.getFirst().readableBytes() >= remaining) {
+      return nextBufferForFrame(remaining);
     }
 
-    Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen);
-    Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen);
+    // Otherwise, create a composite buffer.
+    CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer();
+    while (remaining > 0) {
+      ByteBuf next = nextBufferForFrame(remaining);
+      remaining -= next.readableBytes();
+      frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
+    }
+    assert remaining == 0;
+    return frame;
+  }
+
+  /**
+   * Takes the first buffer in the internal list and either adjusts it to fit in the frame
+   * (by taking a slice out of it) or removes it from the internal list.
+   */
+  private ByteBuf nextBufferForFrame(int bytesToRead) {
+    ByteBuf buf = buffers.getFirst();
+    ByteBuf frame;
+
+    if (buf.readableBytes() > bytesToRead) {
+      frame = buf.retain().readSlice(bytesToRead);
+      totalSize -= bytesToRead;
+    } else {
+      frame = buf;
+      buffers.removeFirst();
+      totalSize -= frame.readableBytes();
+    }
 
-    ByteBuf frame = buffer.readSlice(frameLen);
-    frame.retain();
     return frame;
   }
 
   @Override
   public void channelInactive(ChannelHandlerContext ctx) throws Exception {
-    if (buffer != null) {
-      if (buffer.isReadable()) {
-        feedInterceptor();
-      }
-      buffer.release();
+    for (ByteBuf b : buffers) {
+      b.release();
     }
     if (interceptor != null) {
       interceptor.channelInactive();
     }
+    frameLenBuf.release();
     super.channelInactive(ctx);
   }
 
@@ -141,8 +199,8 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
   /**
    * @return Whether the interceptor is still active after processing the data.
    */
-  private boolean feedInterceptor() throws Exception {
-    if (interceptor != null && !interceptor.handle(buffer)) {
+  private boolean feedInterceptor(ByteBuf buf) throws Exception {
+    if (interceptor != null && !interceptor.handle(buf)) {
       interceptor = null;
     }
     return interceptor != null;

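For context on the arithmetic in decodeFrameSize(): each frame starts with an 8-byte length that counts itself, which is why the decoder subtracts LENGTH_SIZE after readLong(). A hedged sketch of a matching frame writer, mirroring what MessageEncoder produces:

  import io.netty.buffer.ByteBuf;
  import io.netty.buffer.Unpooled;

  final class FrameSketch {
    static ByteBuf frame(byte[] payload) {
      ByteBuf buf = Unpooled.buffer(8 + payload.length);
      buf.writeLong(8 + payload.length);  // total frame size, prefix included
      buf.writeBytes(payload);
      return buf;
    }
  }
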
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
index 50a324e..70c849d 100644
--- a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java
@@ -107,7 +107,10 @@ public class ChunkFetchIntegrationSuite {
     };
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         throw new UnsupportedOperationException();
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 1aa2090..6c8dd74 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -82,10 +82,10 @@ public class ProtocolSuite {
   @Test
   public void requests() {
     testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
-    testClientToServer(new RpcRequest(12345, new byte[0]));
-    testClientToServer(new RpcRequest(12345, new byte[100]));
+    testClientToServer(new RpcRequest(12345, new TestManagedBuffer(0)));
+    testClientToServer(new RpcRequest(12345, new TestManagedBuffer(10)));
     testClientToServer(new StreamRequest("abcde"));
-    testClientToServer(new OneWayMessage(new byte[100]));
+    testClientToServer(new OneWayMessage(new TestManagedBuffer(10)));
   }
 
   @Test
@@ -94,8 +94,8 @@ public class ProtocolSuite {
     testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0)));
     testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error"));
     testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), ""));
-    testServerToClient(new RpcResponse(12345, new byte[0]));
-    testServerToClient(new RpcResponse(12345, new byte[1000]));
+    testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0)));
+    testServerToClient(new RpcResponse(12345, new TestManagedBuffer(100)));
     testServerToClient(new RpcFailure(0, "this is an error"));
     testServerToClient(new RpcFailure(0, ""));
     // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf21206/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index 42955ef..f9b5bf9 100644
--- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -31,6 +31,7 @@ import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 import org.junit.*;
+import static org.junit.Assert.*;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -84,13 +85,16 @@ public class RequestTimeoutIntegrationSuite {
   @Test
   public void timeoutInactiveRequests() throws Exception {
     final Semaphore semaphore = new Semaphore(1);
-    final byte[] response = new byte[16];
+    final int responseSize = 16;
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         try {
           semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
-          callback.onSuccess(response);
+          callback.onSuccess(ByteBuffer.allocate(responseSize));
         } catch (InterruptedException e) {
           // do nothing
         }
@@ -110,15 +114,15 @@ public class RequestTimeoutIntegrationSuite {
     // First completes quickly (semaphore starts at 1).
     TestCallback callback0 = new TestCallback();
     synchronized (callback0) {
-      client.sendRpc(new byte[0], callback0);
+      client.sendRpc(ByteBuffer.allocate(0), callback0);
       callback0.wait(FOREVER);
-      assert (callback0.success.length == response.length);
+      assertEquals(responseSize, callback0.successLength);
     }
 
     // Second times out after 2 seconds, with slack. Must be IOException.
     TestCallback callback1 = new TestCallback();
     synchronized (callback1) {
-      client.sendRpc(new byte[0], callback1);
+      client.sendRpc(ByteBuffer.allocate(0), callback1);
       callback1.wait(4 * 1000);
       assert (callback1.failure != null);
       assert (callback1.failure instanceof IOException);
@@ -131,13 +135,16 @@ public class RequestTimeoutIntegrationSuite {
   @Test
   public void timeoutCleanlyClosesClient() throws Exception {
     final Semaphore semaphore = new Semaphore(0);
-    final byte[] response = new byte[16];
+    final int responseSize = 16;
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         try {
           semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
-          callback.onSuccess(response);
+          callback.onSuccess(ByteBuffer.allocate(responseSize));
         } catch (InterruptedException e) {
           // do nothing
         }
@@ -158,7 +165,7 @@ public class RequestTimeoutIntegrationSuite {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     TestCallback callback0 = new TestCallback();
     synchronized (callback0) {
-      client0.sendRpc(new byte[0], callback0);
+      client0.sendRpc(ByteBuffer.allocate(0), callback0);
       callback0.wait(FOREVER);
       assert (callback0.failure instanceof IOException);
       assert (!client0.isActive());
@@ -170,10 +177,10 @@ public class RequestTimeoutIntegrationSuite {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     TestCallback callback1 = new TestCallback();
     synchronized (callback1) {
-      client1.sendRpc(new byte[0], callback1);
+      client1.sendRpc(ByteBuffer.allocate(0), callback1);
       callback1.wait(FOREVER);
-      assert (callback1.success.length == response.length);
-      assert (callback1.failure == null);
+      assertEquals(responseSize, callback1.successLength);
+      assertNull(callback1.failure);
     }
   }
 
@@ -191,7 +198,10 @@ public class RequestTimeoutIntegrationSuite {
     };
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         throw new UnsupportedOperationException();
       }
 
@@ -218,9 +228,10 @@ public class RequestTimeoutIntegrationSuite {
 
     synchronized (callback0) {
       // not complete yet, but should complete soon
-      assert (callback0.success == null && callback0.failure == null);
+      assertEquals(-1, callback0.successLength);
+      assertNull(callback0.failure);
       callback0.wait(2 * 1000);
-      assert (callback0.failure instanceof IOException);
+      assertTrue(callback0.failure instanceof IOException);
     }
 
     synchronized (callback1) {
@@ -235,13 +246,13 @@ public class RequestTimeoutIntegrationSuite {
    */
   class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
 
-    byte[] success;
+    int successLength = -1;
     Throwable failure;
 
     @Override
-    public void onSuccess(byte[] response) {
+    public void onSuccess(ByteBuffer response) {
       synchronized(this) {
-        success = response;
+        successLength = response.remaining();
         this.notifyAll();
       }
     }
@@ -258,7 +269,7 @@ public class RequestTimeoutIntegrationSuite {
     public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
       synchronized(this) {
         try {
-          success = buffer.nioByteBuffer().array();
+          successLength = buffer.nioByteBuffer().remaining();
           this.notifyAll();
         } catch (IOException e) {
           // weird

