You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2017/01/27 23:08:00 UTC

spark git commit: [SPARK-19365][CORE] Optimize RequestMessage serialization

Repository: spark
Updated Branches:
  refs/heads/master a7ab6f9a8 -> 21aa8c32b


[SPARK-19365][CORE] Optimize RequestMessage serialization

## What changes were proposed in this pull request?

Right now Netty RPC serializes `RequestMessage` using Java serialization, and the size of a single message (e.g., `RequestMessage(..., "hello")`) is almost 1KB.

This PR optimizes it by serializing `RequestMessage` manually (eliminating unnecessary information from most messages, e.g., the class names of `RequestMessage`, `NettyRpcEndpointRef`, ...), which reduces the above message size to 100+ bytes.

## How was this patch tested?

Jenkins

I did a simple test to measure the improvement:

Before:
```
$ bin/spark-shell --master local-cluster[1,4,1024]
...
scala> for (i <- 1 to 10) {
     |   val start = System.nanoTime
     |   val s = sc.parallelize(1 to 1000000, 10 * 1000).count()
     |   val end = System.nanoTime
     |   println(s"$i\t" + ((end - start)/1000/1000))
     | }
1       6830
2       4353
3       3322
4       3107
5       3235
6       3139
7       3156
8       3166
9       3091
10      3029
```
After:
```
$ bin/spark-shell --master local-cluster[1,4,1024]
...
scala> for (i <- 1 to 10) {
     |   val start = System.nanoTime
     |   val s = sc.parallelize(1 to 1000000, 10 * 1000).count()
     |   val end = System.nanoTime
     |   println(s"$i\t" + ((end - start)/1000/1000))
     | }
1       6431
2       3643
3       2913
4       2679
5       2760
6       2710
7       2747
8       2793
9       2679
10      2651
```

I also captured the TCP packets for this test. Before this patch, the total size of the TCP packets was ~1.5GB; after it, the total dropped to ~1.2GB.

Author: Shixiong Zhu <sh...@databricks.com>

Closes #16706 from zsxwing/rpc-opt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21aa8c32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21aa8c32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21aa8c32

Branch: refs/heads/master
Commit: 21aa8c32ba7a29aafc000ecce2e6c802ced6a009
Parents: a7ab6f9
Author: Shixiong Zhu <sh...@databricks.com>
Authored: Fri Jan 27 15:07:57 2017 -0800
Committer: Shixiong Zhu <sh...@databricks.com>
Committed: Fri Jan 27 15:07:57 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/rpc/RpcEndpointAddress.scala   |   5 +-
 .../apache/spark/rpc/netty/NettyRpcEnv.scala    | 119 +++++++++++++++----
 .../spark/rpc/netty/NettyRpcEnvSuite.scala      |  33 ++++-
 .../spark/rpc/netty/NettyRpcHandlerSuite.scala  |   2 +-
 4 files changed, 132 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21aa8c32/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
index b9db60a..fdbccc9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala
@@ -25,10 +25,11 @@ import org.apache.spark.SparkException
  * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
  * connection and can only be reached via the client that sent the endpoint reference.
  *
- * @param rpcAddress The socket address of the endpoint.
+ * @param rpcAddress The socket address of the endpoint. It's `null` when this address pointing to
+ *                   an endpoint in a client `NettyRpcEnv`.
  * @param name Name of the endpoint.
  */
-private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) {
+private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) {
 
   require(name != null, "RpcEndpoint name must be provided.")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/21aa8c32/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 1e448b2..ff5e39a 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -37,8 +37,8 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.server._
 import org.apache.spark.rpc._
-import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}
 
 private[netty] class NettyRpcEnv(
     val conf: SparkConf,
@@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv(
       }
     } else {
       // Message to a remote RPC endpoint.
-      postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
+      postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
     }
   }
 
@@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv(
         }(ThreadUtils.sameThread)
         dispatcher.postLocalMessage(message, p)
       } else {
-        val rpcMessage = RpcOutboxMessage(serialize(message),
+        val rpcMessage = RpcOutboxMessage(message.serialize(this),
           onFailure,
           (client, response) => onSuccess(deserialize[Any](client, response)))
         postToOutbox(message.receiver, rpcMessage)
@@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv(
     javaSerializerInstance.serialize(content)
   }
 
+  /**
+   * Returns [[SerializationStream]] that forwards the serialized bytes to `out`.
+   */
+  private[netty] def serializeStream(out: OutputStream): SerializationStream = {
+    javaSerializerInstance.serializeStream(out)
+  }
+
   private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
     NettyRpcEnv.currentClient.withValue(client) {
       deserialize { () =>
@@ -480,16 +487,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  */
 private[netty] class NettyRpcEndpointRef(
     @transient private val conf: SparkConf,
-    endpointAddress: RpcEndpointAddress,
-    @transient @volatile private var nettyEnv: NettyRpcEnv)
-  extends RpcEndpointRef(conf) with Serializable with Logging {
+    private val endpointAddress: RpcEndpointAddress,
+    @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
 
   @transient @volatile var client: TransportClient = _
 
-  private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
-  private val _name = endpointAddress.name
-
-  override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
+  override def address: RpcAddress =
+    if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null
 
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject()
@@ -501,34 +505,103 @@ private[netty] class NettyRpcEndpointRef(
     out.defaultWriteObject()
   }
 
-  override def name: String = _name
+  override def name: String = endpointAddress.name
 
   override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
-    nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
+    nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
   }
 
   override def send(message: Any): Unit = {
     require(message != null, "Message is null")
-    nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
+    nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
   }
 
-  override def toString: String = s"NettyRpcEndpointRef(${_address})"
-
-  def toURI: URI = new URI(_address.toString)
+  override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})"
 
   final override def equals(that: Any): Boolean = that match {
-    case other: NettyRpcEndpointRef => _address == other._address
+    case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress
     case _ => false
   }
 
-  final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
+  final override def hashCode(): Int =
+    if (endpointAddress == null) 0 else endpointAddress.hashCode()
 }
 
 /**
  * The message that is sent from the sender to the receiver.
+ *
+ * @param senderAddress the sender address. It's `null` if this message is from a client
+ *                      `NettyRpcEnv`.
+ * @param receiver the receiver of this message.
+ * @param content the message content.
  */
-private[netty] case class RequestMessage(
-    senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
+private[netty] class RequestMessage(
+    val senderAddress: RpcAddress,
+    val receiver: NettyRpcEndpointRef,
+    val content: Any) {
+
+  /** Manually serialize [[RequestMessage]] to minimize the size. */
+  def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
+    val bos = new ByteBufferOutputStream()
+    val out = new DataOutputStream(bos)
+    try {
+      writeRpcAddress(out, senderAddress)
+      writeRpcAddress(out, receiver.address)
+      out.writeUTF(receiver.name)
+      val s = nettyEnv.serializeStream(out)
+      try {
+        s.writeObject(content)
+      } finally {
+        s.close()
+      }
+    } finally {
+      out.close()
+    }
+    bos.toByteBuffer
+  }
+
+  private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = {
+    if (rpcAddress == null) {
+      out.writeBoolean(false)
+    } else {
+      out.writeBoolean(true)
+      out.writeUTF(rpcAddress.host)
+      out.writeInt(rpcAddress.port)
+    }
+  }
+
+  override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)"
+}
+
+private[netty] object RequestMessage {
+
+  private def readRpcAddress(in: DataInputStream): RpcAddress = {
+    val hasRpcAddress = in.readBoolean()
+    if (hasRpcAddress) {
+      RpcAddress(in.readUTF(), in.readInt())
+    } else {
+      null
+    }
+  }
+
+  def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = {
+    val bis = new ByteBufferInputStream(bytes)
+    val in = new DataInputStream(bis)
+    try {
+      val senderAddress = readRpcAddress(in)
+      val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF())
+      val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv)
+      ref.client = client
+      new RequestMessage(
+        senderAddress,
+        ref,
+        // The remaining bytes in `bytes` are the message content.
+        nettyEnv.deserialize(client, bytes))
+    } finally {
+      in.close()
+    }
+  }
+}
 
 /**
  * A response that indicates some failure happens in the receiver side.
@@ -574,10 +647,10 @@ private[netty] class NettyRpcHandler(
     val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
     val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
-    val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
+    val requestMessage = RequestMessage(nettyEnv, client, message)
     if (requestMessage.senderAddress == null) {
       // Create a new message with the socket address of the client as the sender.
-      RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+      new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
     } else {
       // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
       // the listening address

http://git-wip-us.apache.org/repos/asf/spark/blob/21aa8c32/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index 0409aa3..2b1bce4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.rpc.netty
 
+import org.scalatest.mock.MockitoSugar
+
 import org.apache.spark._
+import org.apache.spark.network.client.TransportClient
 import org.apache.spark.rpc._
 
-class NettyRpcEnvSuite extends RpcEnvSuite {
+class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
 
   override def createRpcEnv(
       conf: SparkConf,
@@ -53,4 +56,32 @@ class NettyRpcEnvSuite extends RpcEnvSuite {
     }
   }
 
+  test("RequestMessage serialization") {
+    def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = {
+      assert(expected.senderAddress === actual.senderAddress)
+      assert(expected.receiver === actual.receiver)
+      assert(expected.content === actual.content)
+    }
+
+    val nettyEnv = env.asInstanceOf[NettyRpcEnv]
+    val client = mock[TransportClient]
+    val senderAddress = RpcAddress("locahost", 12345)
+    val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
+    val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv)
+
+    val msg = new RequestMessage(senderAddress, receiver, "foo")
+    assertRequestMessageEquals(
+      msg,
+      RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))
+
+    val msg2 = new RequestMessage(null, receiver, "foo")
+    assertRequestMessageEquals(
+      msg2,
+      RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))
+
+    val msg3 = new RequestMessage(senderAddress, receiver, null)
+    assertRequestMessageEquals(
+      msg3,
+      RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/21aa8c32/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index 0c156fe..a71d872 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
   val env = mock(classOf[NettyRpcEnv])
   val sm = mock(classOf[StreamManager])
   when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
-    .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
+    .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null))
 
   test("receive") {
     val dispatcher = mock(classOf[Dispatcher])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org