Commit 264b0f36 authored by Marcelo Vanzin

[SPARK-21408][CORE] Better default number of RPC dispatch threads.

Instead of using the host's cpu count, use the number of cores allocated
for the Spark process when sizing the RPC dispatch thread pool. This avoids
creating large thread pools on large machines when the number of allocated
cores is small.
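
In isolation, the new sizing rule looks like the following sketch (extracted
from the Dispatcher change below; the standalone helper name is hypothetical):

    import org.apache.spark.SparkConf

    // numUsableCores == 0 means "allocation unknown"; only then fall back to
    // the host's CPU count. An explicitly configured value always wins.
    def dispatcherThreads(conf: SparkConf, numUsableCores: Int): Int = {
      val availableCores =
        if (numUsableCores > 0) numUsableCores
        else Runtime.getRuntime.availableProcessors()
      conf.getInt("spark.rpc.netty.dispatcher.numThreads",
        math.max(2, availableCores))
    }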

Tested by verifying the number of dispatcher threads with spark.executor.cores
set to 1 and 4; same check for the YARN AM.
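
One way to reproduce that check (a sketch, not necessarily the method used;
it assumes only that the pool threads carry the "dispatcher-event-loop" name
prefix seen in Dispatcher.scala below):

    import scala.collection.JavaConverters._

    // Count live RPC dispatcher threads in the current JVM; run inside an
    // executor (or the YARN AM) after setting spark.executor.cores.
    val dispatcherThreads = Thread.getAllStackTraces.keySet.asScala
      .count(_.getName.startsWith("dispatcher-event-loop"))
    println(s"dispatcher threads: $dispatcherThreads")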

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #18639 from vanzin/SPARK-21408.
parent cde64add
core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -243,7 +243,7 @@ object SparkEnv extends Logging {
     val systemName = if (isDriver) driverSystemName else executorSystemName
     val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
-      securityManager, clientMode = !isDriver)
+      securityManager, numUsableCores, !isDriver)
     // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
     if (isDriver) {
...
core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -40,7 +40,7 @@ private[spark] object RpcEnv {
       conf: SparkConf,
       securityManager: SecurityManager,
       clientMode: Boolean = false): RpcEnv = {
-    create(name, host, host, port, conf, securityManager, clientMode)
+    create(name, host, host, port, conf, securityManager, 0, clientMode)
   }

   def create(
@@ -50,9 +50,10 @@ private[spark] object RpcEnv {
       port: Int,
       conf: SparkConf,
       securityManager: SecurityManager,
+      numUsableCores: Int,
       clientMode: Boolean): RpcEnv = {
     val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
-      clientMode)
+      numUsableCores, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 }
@@ -201,4 +202,5 @@ private[spark] case class RpcEnvConfig(
     advertiseAddress: String,
     port: Int,
     securityManager: SecurityManager,
+    numUsableCores: Int,
     clientMode: Boolean)
core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -32,8 +32,11 @@ import org.apache.spark.util.ThreadUtils
 /**
  * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ *
+ * @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool.
+ *                       If 0, will consider the available CPUs on the host.
  */
-private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
+private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {

   private class EndpointData(
       val name: String,
@@ -192,8 +195,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

   /** Thread pool used for dispatching messages. */
   private val threadpool: ThreadPoolExecutor = {
+    val availableCores =
+      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
     val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
-      math.max(2, Runtime.getRuntime.availableProcessors()))
+      math.max(2, availableCores))
     val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
     for (i <- 0 until numThreads) {
       pool.execute(new MessageLoop)
...
core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -44,14 +44,15 @@ private[netty] class NettyRpcEnv(
     val conf: SparkConf,
     javaSerializerInstance: JavaSerializerInstance,
     host: String,
-    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
+    securityManager: SecurityManager,
+    numUsableCores: Int) extends RpcEnv(conf) with Logging {

   private[netty] val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
     "rpc",
     conf.getInt("spark.rpc.io.threads", 0))

-  private val dispatcher: Dispatcher = new Dispatcher(this)
+  private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)

   private val streamManager = new NettyStreamManager(this)
@@ -451,7 +452,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
       new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
     val nettyEnv =
       new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
-        config.securityManager)
+        config.securityManager, config.numUsableCores)
     if (!config.clientMode) {
       val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
         nettyEnv.startServer(config.bindAddress, actualPort)
...
core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -31,7 +31,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
       port: Int,
       clientMode: Boolean = false): RpcEnv = {
     val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port,
-      new SecurityManager(conf), clientMode)
+      new SecurityManager(conf), 0, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
@@ -47,7 +47,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
   test("advertise address different from bind address") {
     val sparkConf = new SparkConf()
     val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0,
-      new SecurityManager(sparkConf), false)
+      new SecurityManager(sparkConf), 0, false)
     val env = new NettyRpcEnvFactory().create(config)
     try {
       assert(env.address.hostPort.startsWith("example.com:"))
...
resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -459,8 +459,10 @@ private[spark] class ApplicationMaster(
   }

   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
-    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr,
-      clientMode = true)
+    val hostname = Utils.localHostName
+    val amCores = sparkConf.get(AM_CORES)
+    rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
+      amCores, true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
     registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
...
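
For context, AM_CORES is the existing config entry for "spark.yarn.am.cores"
in org.apache.spark.deploy.yarn.config; a sketch of its definition (recalled
from the Spark source, not part of this diff):

    import org.apache.spark.internal.config.ConfigBuilder

    // Cores requested for the YARN ApplicationMaster container (default 1).
    private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores")
      .intConf
      .createWithDefault(1)

So in client mode the AM's dispatcher pool is sized from the cores actually
requested for the AM container rather than from the host's CPU count.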