Commit cf2e0ae7 authored by Reynold Xin

[SPARK-11096] Post-hoc review Netty based RPC implementation - round 2

A few more changes:

1. Renamed IDVerifier -> RpcEndpointVerifier
2. Renamed NettyRpcAddress -> RpcEndpointAddress
3. Simplified NettyRpcHandler a bit by removing the connection count tracking. This is OK because I now force spark.shuffle.io.numConnectionsPerPeer to 1 for RPC (see the sketch after this list).
4. Reduced the default of spark.rpc.connect.threads to 64. It would be great to eventually remove this extra thread pool.
5. Minor cleanup & documentation.
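
For context on items 3 and 4, here is a minimal sketch of the two settings involved, mirroring the NettyRpcEnv hunk below (Spark-internal API, illustrative only):

    import org.apache.spark.SparkConf
    import org.apache.spark.network.netty.SparkTransportConf

    // The RPC layer clones the conf and pins numConnectionsPerPeer to 1, so only a
    // single connection is opened per remote peer (item 3). The client connection
    // pool now defaults to 64 threads (item 4).
    val conf = new SparkConf()
    val transportConf = SparkTransportConf.fromSparkConf(
      conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
      conf.getInt("spark.rpc.io.threads", 0))
    val connectThreads = conf.getInt("spark.rpc.connect.threads", 64)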

Author: Reynold Xin <rxin@databricks.com>

Closes #9112 from rxin/SPARK-11096.
parent 615cc858
......@@ -93,15 +93,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
}
/**
* Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`
* asynchronously.
*/
def asyncSetupEndpointRef(
systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = {
asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName))
}
/**
* Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`.
* This is a blocking action.
......
......@@ -29,6 +29,9 @@ import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
import org.apache.spark.util.ThreadUtils
/**
* A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
*/
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private class EndpointData(
......@@ -42,7 +45,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
// Track the receivers whose inboxes may contain messages.
private val receivers = new LinkedBlockingQueue[EndpointData]()
private val receivers = new LinkedBlockingQueue[EndpointData]
/**
* True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced
......@@ -52,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
private var stopped = false
def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name)
val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name)
val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
synchronized {
if (stopped) {
......
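
A hedged usage sketch of the registration path above (Spark-internal API; "my-endpoint" and myEndpoint are hypothetical names):

    // Registering an endpoint by name returns a ref; dispatcher.verify is the
    // check that RpcEndpointVerifier answers for remote peers (see the last hunks below).
    val ref: NettyRpcEndpointRef = dispatcher.registerRpcEndpoint("my-endpoint", myEndpoint)
    assert(dispatcher.verify("my-endpoint"))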
......@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
import java.util.concurrent._
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag
......@@ -45,8 +44,10 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
private val transportConf =
SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
// Override numConnectionsPerPeer to 1 for RPC.
private val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"),
conf.getInt("spark.rpc.io.threads", 0))
private val dispatcher: Dispatcher = new Dispatcher(this)
......@@ -54,14 +55,14 @@ private[netty] class NettyRpcEnv(
new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
private val clientFactory = {
val bootstraps: Seq[TransportClientBootstrap] =
val bootstraps: java.util.List[TransportClientBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
Seq(new SaslClientBootstrap(transportConf, "", securityManager,
java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,
securityManager.isSaslEncryptionEnabled()))
} else {
Nil
java.util.Collections.emptyList[TransportClientBootstrap]
}
transportContext.createClientFactory(bootstraps.asJava)
transportContext.createClientFactory(bootstraps)
}
val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
......@@ -71,7 +72,7 @@ private[netty] class NettyRpcEnv(
// TODO: a non-blocking TransportClientFactory.createClient in future
private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
"netty-rpc-connection",
conf.getInt("spark.rpc.connect.threads", 256))
conf.getInt("spark.rpc.connect.threads", 64))
@volatile private var server: TransportServer = _
......@@ -83,7 +84,8 @@ private[netty] class NettyRpcEnv(
java.util.Collections.emptyList()
}
server = transportContext.createServer(port, bootstraps)
dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
dispatcher.registerRpcEndpoint(
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
}
override lazy val address: RpcAddress = {
......@@ -96,11 +98,11 @@ private[netty] class NettyRpcEnv(
}
def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
val addr = NettyRpcAddress(uri)
val addr = RpcEndpointAddress(uri)
val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
val idVerifierRef =
new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find =>
val verifier = new NettyRpcEndpointRef(
conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this)
verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>
if (find) {
Future.successful(endpointRef)
} else {
......@@ -117,16 +119,18 @@ private[netty] class NettyRpcEnv(
private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
// Message to a local RPC endpoint.
val promise = Promise[Any]()
dispatcher.postLocalMessage(message, promise)
promise.future.onComplete {
case Success(response) =>
val ack = response.asInstanceOf[Ack]
logDebug(s"Receive ack from ${ack.sender}")
logTrace(s"Received ack from ${ack.sender}")
case Failure(e) =>
logError(s"Exception when sending $message", e)
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
......@@ -204,11 +208,10 @@ private[netty] class NettyRpcEnv(
}
})
} catch {
case e: RejectedExecutionException => {
case e: RejectedExecutionException =>
if (!promise.tryFailure(e)) {
logWarning(s"Ignore failure", e)
}
}
}
}
promise.future
......@@ -231,7 +234,7 @@ private[netty] class NettyRpcEnv(
}
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
new NettyRpcAddress(address.host, address.port, endpointName).toString
new RpcEndpointAddress(address.host, address.port, endpointName).toString
override def shutdown(): Unit = {
cleanup()
......@@ -310,9 +313,9 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
@transient @volatile private var nettyEnv: NettyRpcEnv = _
@transient @volatile private var _address: NettyRpcAddress = _
@transient @volatile private var _address: RpcEndpointAddress = _
def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) {
this(conf)
this._address = _address
this.nettyEnv = nettyEnv
......@@ -322,7 +325,7 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf)
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
_address = in.readObject().asInstanceOf[NettyRpcAddress]
_address = in.readObject().asInstanceOf[RpcEndpointAddress]
nettyEnv = NettyRpcEnv.currentEnv.value
}
......@@ -406,49 +409,37 @@ private[netty] class NettyRpcHandler(
private type RemoteEnvAddress = RpcAddress
// Store all client addresses and their NettyRpcEnv addresses.
// TODO: Is this even necessary?
@GuardedBy("this")
private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
// Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection
// count because `TransportClientFactory.createClient` will create multiple connections
// (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection
// to send the message. See `TransportClientFactory.createClient` for more details.
@GuardedBy("this")
private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
override def receive(
client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
val requestMessage = nettyEnv.deserialize[RequestMessage](message)
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
assert(addr != null)
val remoteEnvAddress = requestMessage.senderAddress
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage: Option[RemoteProcessConnected] =
synchronized {
// If the first connection to a remote RpcEnv is found, we should broadcast "Associated"
if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
// clientAddr connects at the first time
val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
// Increase the connection number of remoteEnvAddress
remoteConnectionCount.put(remoteEnvAddress, count + 1)
if (count == 0) {
// This is the first connection, so fire "Associated"
Some(RemoteProcessConnected(remoteEnvAddress))
} else {
None
}
} else {
None
}
// TODO: Can we add connection callback (channel registered) to the underlying framework?
// A variable to track whether we should dispatch the RemoteProcessConnected message.
var dispatchRemoteProcessConnected = false
synchronized {
if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) {
// clientAddr connects at the first time, fire "RemoteProcessConnected"
dispatchRemoteProcessConnected = true
}
broadcastMessage.foreach(dispatcher.postToAll)
}
if (dispatchRemoteProcessConnected) {
dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
}
dispatcher.postRemoteMessage(requestMessage, callback)
}
override def getStreamManager: StreamManager = new OneForOneStreamManager
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage =
......@@ -469,34 +460,21 @@ private[netty] class NettyRpcHandler(
}
override def connectionTerminated(client: TransportClient): Unit = {
val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
val broadcastMessage =
synchronized {
// If the last connection to a remote RpcEnv is terminated, we should broadcast
// "Disassociated"
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
remoteAddresses -= clientAddr
val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent")
if (count - 1 == 0) {
// We lost all clients, so clean up and fire "Disassociated"
remoteConnectionCount.remove(remoteEnvAddress)
Some(RemoteProcessDisconnected(remoteEnvAddress))
} else {
// Decrease the connection number of remoteEnvAddress
remoteConnectionCount.put(remoteEnvAddress, count - 1)
None
}
}
val messageOpt: Option[RemoteProcessDisconnected] =
synchronized {
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
remoteAddresses -= clientAddr
Some(RemoteProcessDisconnected(remoteEnvAddress))
}
broadcastMessage.foreach(dispatcher.postToAll)
}
messageOpt.foreach(dispatcher.postToAll)
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
// See java.net.Socket.getRemoteSocketAddress
}
}
}
......@@ -17,40 +17,44 @@
package org.apache.spark.rpc.netty
import java.net.URI
import org.apache.spark.SparkException
import org.apache.spark.rpc.RpcAddress
private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) {
/**
* An address identifier for an RPC endpoint.
*
* @param host host name of the remote process.
* @param port the port the remote RPC environment binds to.
* @param name name of the remote endpoint.
*/
private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) {
def toRpcAddress: RpcAddress = RpcAddress(host, port)
override val toString = s"spark://$name@$host:$port"
}
private[netty] object NettyRpcAddress {
private[netty] object RpcEndpointAddress {
def apply(sparkUrl: String): NettyRpcAddress = {
def apply(sparkUrl: String): RpcEndpointAddress = {
try {
val uri = new URI(sparkUrl)
val uri = new java.net.URI(sparkUrl)
val host = uri.getHost
val port = uri.getPort
val name = uri.getUserInfo
if (uri.getScheme != "spark" ||
host == null ||
port < 0 ||
name == null ||
(uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
uri.getFragment != null ||
uri.getQuery != null) {
host == null ||
port < 0 ||
name == null ||
(uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null
uri.getFragment != null ||
uri.getQuery != null) {
throw new SparkException("Invalid Spark URL: " + sparkUrl)
}
NettyRpcAddress(host, port, name)
RpcEndpointAddress(host, port, name)
} catch {
case e: java.net.URISyntaxException =>
throw new SparkException("Invalid Spark URL: " + sparkUrl, e)
}
}
}
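
A short sketch of the renamed address type, assuming access to the org.apache.spark.rpc.netty package (the class is private[netty]); the URL form matches NettyRpcAddressSuite below:

    import org.apache.spark.rpc.RpcAddress

    val addr = RpcEndpointAddress("localhost", 12345, "test")
    assert(addr.toString == "spark://test@localhost:12345")

    // The companion's apply(sparkUrl) parses the same form back; a wrong scheme,
    // missing host/port, or missing endpoint name raises SparkException.
    val parsed = RpcEndpointAddress("spark://test@localhost:12345")
    assert(parsed == addr)
    assert(parsed.toRpcAddress == RpcAddress("localhost", 12345))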
......@@ -14,26 +14,27 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rpc.netty
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
/**
* A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists
*/
private[netty] case class ID(name: String)
/**
* An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]]
* An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists.
*
* This is used when setting up a remote endpoint reference.
*/
private[netty] class IDVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
extends RpcEndpoint {
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case ID(name) => context.reply(dispatcher.verify(name))
case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
}
}
private[netty] object IDVerifier {
val NAME = "id-verifier"
private[netty] object RpcEndpointVerifier {
val NAME = "endpoint-verifier"
/** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */
case class CheckExistence(name: String)
}
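
From the caller's side, a hedged sketch of how the verifier is exercised (assumes a running RpcEnv named rpcEnv; "driver-endpoint", driverHost and driverPort are placeholders). The blocking lookup below goes through the CheckExistence round trip shown in asyncSetupEndpointRefByURI above:

    val uri = RpcEndpointAddress(driverHost, driverPort, "driver-endpoint").toString
    // Blocks until the remote verifier confirms the endpoint exists, and fails otherwise.
    val ref = rpcEnv.setupEndpointRefByURI(uri)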
......@@ -22,7 +22,7 @@ import org.apache.spark.SparkFunSuite
class NettyRpcAddressSuite extends SparkFunSuite {
test("toString") {
val addr = NettyRpcAddress("localhost", 12345, "test")
val addr = RpcEndpointAddress("localhost", 12345, "test")
assert(addr.toString === "spark://test@localhost:12345")
}
......
......@@ -42,9 +42,6 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
nettyRpcHandler.receive(client, null, null)
when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40001))
nettyRpcHandler.receive(client, null, null)
verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345)))
}
......