Skip to content
Snippets Groups Projects
Commit 21aa8c32 authored by Shixiong Zhu's avatar Shixiong Zhu
Browse files

[SPARK-19365][CORE] Optimize RequestMessage serialization

## What changes were proposed in this pull request?

Right now Netty PRC serializes `RequestMessage` using Java serialization, and the size of a single message (e.g., RequestMessage(..., "hello")`) is almost 1KB.

This PR optimizes it by serializing `RequestMessage` manually (eliminate unnecessary information from most messages, e.g., class names of `RequestMessage`, `NettyRpcEndpointRef`, ...), and reduces the above message size to 100+ bytes.

## How was this patch tested?

Jenkins

I did a simple test to measure the improvement:

Before
```
$ bin/spark-shell --master local-cluster[1,4,1024]
...
scala> for (i <- 1 to 10) {
     |   val start = System.nanoTime
     |   val s = sc.parallelize(1 to 1000000, 10 * 1000).count()
     |   val end = System.nanoTime
     |   println(s"$i\t" + ((end - start)/1000/1000))
     | }
1       6830
2       4353
3       3322
4       3107
5       3235
6       3139
7       3156
8       3166
9       3091
10      3029
```
After:
```
$ bin/spark-shell --master local-cluster[1,4,1024]
...
scala> for (i <- 1 to 10) {
     |   val start = System.nanoTime
     |   val s = sc.parallelize(1 to 1000000, 10 * 1000).count()
     |   val end = System.nanoTime
     |   println(s"$i\t" + ((end - start)/1000/1000))
     | }
1       6431
2       3643
3       2913
4       2679
5       2760
6       2710
7       2747
8       2793
9       2679
10      2651
```

I also captured the TCP packets for this test. Before this patch, the total size of TCP packets is ~1.5GB. After it, it reduces to ~1.2GB.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16706 from zsxwing/rpc-opt.
parent a7ab6f9a
No related branches found
No related tags found
No related merge requests found
...@@ -25,10 +25,11 @@ import org.apache.spark.SparkException ...@@ -25,10 +25,11 @@ import org.apache.spark.SparkException
* The `rpcAddress` may be null, in which case the endpoint is registered via a client-only * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only
* connection and can only be reached via the client that sent the endpoint reference. * connection and can only be reached via the client that sent the endpoint reference.
* *
* @param rpcAddress The socket address of the endpoint. * @param rpcAddress The socket address of the endpoint. It's `null` when this address pointing to
* an endpoint in a client `NettyRpcEnv`.
* @param name Name of the endpoint. * @param name Name of the endpoint.
*/ */
private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) {
require(name != null, "RpcEndpoint name must be provided.") require(name != null, "RpcEndpoint name must be provided.")
......
...@@ -37,8 +37,8 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap ...@@ -37,8 +37,8 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap
import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.server._ import org.apache.spark.network.server._
import org.apache.spark.rpc._ import org.apache.spark.rpc._
import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream}
import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils}
private[netty] class NettyRpcEnv( private[netty] class NettyRpcEnv(
val conf: SparkConf, val conf: SparkConf,
...@@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv( ...@@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv(
} }
} else { } else {
// Message to a remote RPC endpoint. // Message to a remote RPC endpoint.
postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this)))
} }
} }
...@@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv( ...@@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv(
}(ThreadUtils.sameThread) }(ThreadUtils.sameThread)
dispatcher.postLocalMessage(message, p) dispatcher.postLocalMessage(message, p)
} else { } else {
val rpcMessage = RpcOutboxMessage(serialize(message), val rpcMessage = RpcOutboxMessage(message.serialize(this),
onFailure, onFailure,
(client, response) => onSuccess(deserialize[Any](client, response))) (client, response) => onSuccess(deserialize[Any](client, response)))
postToOutbox(message.receiver, rpcMessage) postToOutbox(message.receiver, rpcMessage)
...@@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv( ...@@ -253,6 +253,13 @@ private[netty] class NettyRpcEnv(
javaSerializerInstance.serialize(content) javaSerializerInstance.serialize(content)
} }
/**
* Returns [[SerializationStream]] that forwards the serialized bytes to `out`.
*/
private[netty] def serializeStream(out: OutputStream): SerializationStream = {
javaSerializerInstance.serializeStream(out)
}
private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
NettyRpcEnv.currentClient.withValue(client) { NettyRpcEnv.currentClient.withValue(client) {
deserialize { () => deserialize { () =>
...@@ -480,16 +487,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { ...@@ -480,16 +487,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
*/ */
private[netty] class NettyRpcEndpointRef( private[netty] class NettyRpcEndpointRef(
@transient private val conf: SparkConf, @transient private val conf: SparkConf,
endpointAddress: RpcEndpointAddress, private val endpointAddress: RpcEndpointAddress,
@transient @volatile private var nettyEnv: NettyRpcEnv) @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {
extends RpcEndpointRef(conf) with Serializable with Logging {
@transient @volatile var client: TransportClient = _ @transient @volatile var client: TransportClient = _
private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null override def address: RpcAddress =
private val _name = endpointAddress.name if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null
override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
private def readObject(in: ObjectInputStream): Unit = { private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject() in.defaultReadObject()
...@@ -501,34 +505,103 @@ private[netty] class NettyRpcEndpointRef( ...@@ -501,34 +505,103 @@ private[netty] class NettyRpcEndpointRef(
out.defaultWriteObject() out.defaultWriteObject()
} }
override def name: String = _name override def name: String = endpointAddress.name
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
} }
override def send(message: Any): Unit = { override def send(message: Any): Unit = {
require(message != null, "Message is null") require(message != null, "Message is null")
nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) nettyEnv.send(new RequestMessage(nettyEnv.address, this, message))
} }
override def toString: String = s"NettyRpcEndpointRef(${_address})" override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})"
def toURI: URI = new URI(_address.toString)
final override def equals(that: Any): Boolean = that match { final override def equals(that: Any): Boolean = that match {
case other: NettyRpcEndpointRef => _address == other._address case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress
case _ => false case _ => false
} }
final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() final override def hashCode(): Int =
if (endpointAddress == null) 0 else endpointAddress.hashCode()
} }
/** /**
* The message that is sent from the sender to the receiver. * The message that is sent from the sender to the receiver.
*
* @param senderAddress the sender address. It's `null` if this message is from a client
* `NettyRpcEnv`.
* @param receiver the receiver of this message.
* @param content the message content.
*/ */
private[netty] case class RequestMessage( private[netty] class RequestMessage(
senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) val senderAddress: RpcAddress,
val receiver: NettyRpcEndpointRef,
val content: Any) {
/** Manually serialize [[RequestMessage]] to minimize the size. */
def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = {
val bos = new ByteBufferOutputStream()
val out = new DataOutputStream(bos)
try {
writeRpcAddress(out, senderAddress)
writeRpcAddress(out, receiver.address)
out.writeUTF(receiver.name)
val s = nettyEnv.serializeStream(out)
try {
s.writeObject(content)
} finally {
s.close()
}
} finally {
out.close()
}
bos.toByteBuffer
}
private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = {
if (rpcAddress == null) {
out.writeBoolean(false)
} else {
out.writeBoolean(true)
out.writeUTF(rpcAddress.host)
out.writeInt(rpcAddress.port)
}
}
override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)"
}
private[netty] object RequestMessage {
private def readRpcAddress(in: DataInputStream): RpcAddress = {
val hasRpcAddress = in.readBoolean()
if (hasRpcAddress) {
RpcAddress(in.readUTF(), in.readInt())
} else {
null
}
}
def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = {
val bis = new ByteBufferInputStream(bytes)
val in = new DataInputStream(bis)
try {
val senderAddress = readRpcAddress(in)
val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF())
val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv)
ref.client = client
new RequestMessage(
senderAddress,
ref,
// The remaining bytes in `bytes` are the message content.
nettyEnv.deserialize(client, bytes))
} finally {
in.close()
}
}
}
/** /**
* A response that indicates some failure happens in the receiver side. * A response that indicates some failure happens in the receiver side.
...@@ -574,10 +647,10 @@ private[netty] class NettyRpcHandler( ...@@ -574,10 +647,10 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null) assert(addr != null)
val clientAddr = RpcAddress(addr.getHostString, addr.getPort) val clientAddr = RpcAddress(addr.getHostString, addr.getPort)
val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) val requestMessage = RequestMessage(nettyEnv, client, message)
if (requestMessage.senderAddress == null) { if (requestMessage.senderAddress == null) {
// Create a new message with the socket address of the client as the sender. // Create a new message with the socket address of the client as the sender.
RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
} else { } else {
// The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
// the listening address // the listening address
......
...@@ -17,10 +17,13 @@ ...@@ -17,10 +17,13 @@
package org.apache.spark.rpc.netty package org.apache.spark.rpc.netty
import org.scalatest.mock.MockitoSugar
import org.apache.spark._ import org.apache.spark._
import org.apache.spark.network.client.TransportClient
import org.apache.spark.rpc._ import org.apache.spark.rpc._
class NettyRpcEnvSuite extends RpcEnvSuite { class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
override def createRpcEnv( override def createRpcEnv(
conf: SparkConf, conf: SparkConf,
...@@ -53,4 +56,32 @@ class NettyRpcEnvSuite extends RpcEnvSuite { ...@@ -53,4 +56,32 @@ class NettyRpcEnvSuite extends RpcEnvSuite {
} }
} }
test("RequestMessage serialization") {
def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = {
assert(expected.senderAddress === actual.senderAddress)
assert(expected.receiver === actual.receiver)
assert(expected.content === actual.content)
}
val nettyEnv = env.asInstanceOf[NettyRpcEnv]
val client = mock[TransportClient]
val senderAddress = RpcAddress("locahost", 12345)
val receiverAddress = RpcEndpointAddress("localhost", 54321, "test")
val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv)
val msg = new RequestMessage(senderAddress, receiver, "foo")
assertRequestMessageEquals(
msg,
RequestMessage(nettyEnv, client, msg.serialize(nettyEnv)))
val msg2 = new RequestMessage(null, receiver, "foo")
assertRequestMessageEquals(
msg2,
RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv)))
val msg3 = new RequestMessage(senderAddress, receiver, null)
assertRequestMessageEquals(
msg3,
RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv)))
}
} }
...@@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { ...@@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
val env = mock(classOf[NettyRpcEnv]) val env = mock(classOf[NettyRpcEnv])
val sm = mock(classOf[StreamManager]) val sm = mock(classOf[StreamManager])
when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
.thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null))
test("receive") { test("receive") {
val dispatcher = mock(classOf[Dispatcher]) val dispatcher = mock(classOf[Dispatcher])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment