diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 45ed9860ea8f993682cde0f797f9b68c635abfa8..24928150315e884f7efeedb591b498a33aee71b8 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -243,7 +243,7 @@ object SparkEnv extends Logging {
 
     val systemName = if (isDriver) driverSystemName else executorSystemName
     val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
-      securityManager, clientMode = !isDriver)
+      securityManager, numUsableCores, !isDriver)
 
     // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
     if (isDriver) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 530743c03640b3e4bd0c1d8f7dd949bef7e8fb91..de2cc56bc6b1667545db0d3a2d87bbda9ebb86b7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -40,7 +40,7 @@ private[spark] object RpcEnv {
       conf: SparkConf,
       securityManager: SecurityManager,
       clientMode: Boolean = false): RpcEnv = {
-    create(name, host, host, port, conf, securityManager, clientMode)
+    create(name, host, host, port, conf, securityManager, 0, clientMode)
   }
 
   def create(
@@ -50,9 +50,10 @@ private[spark] object RpcEnv {
       port: Int,
       conf: SparkConf,
       securityManager: SecurityManager,
+      numUsableCores: Int,
       clientMode: Boolean): RpcEnv = {
     val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
-      clientMode)
+      numUsableCores, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 }
@@ -201,4 +202,5 @@ private[spark] case class RpcEnvConfig(
     advertiseAddress: String,
     port: Int,
     securityManager: SecurityManager,
+    numUsableCores: Int,
     clientMode: Boolean)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index e94babb846128df28c0bcfb1b3ed5fe03f36b87e..904c4d02dd2a46b253c021a79c3da6f974202e0c 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -32,8 +32,11 @@ import org.apache.spark.util.ThreadUtils
 
 /**
  * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s).
+ *
+ * @param numUsableCores Number of CPU cores allocated to the process, for sizing the thread pool.
+ *                       If 0, will consider the available CPUs on the host.
  */
-private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
+private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging {
 
   private class EndpointData(
       val name: String,
@@ -192,8 +195,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
 
   /** Thread pool used for dispatching messages. */
   private val threadpool: ThreadPoolExecutor = {
+    val availableCores =
+      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
     val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
-      math.max(2, Runtime.getRuntime.availableProcessors()))
+      math.max(2, availableCores))
     val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
     for (i <- 0 until numThreads) {
       pool.execute(new MessageLoop)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 64898499246ac035a2ac97aadfe9636a78f5c4d2..1777e7a539751561f838d3d9de59e0b99b3840e7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -44,14 +44,15 @@ private[netty] class NettyRpcEnv(
     val conf: SparkConf,
     javaSerializerInstance: JavaSerializerInstance,
     host: String,
-    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {
+    securityManager: SecurityManager,
+    numUsableCores: Int) extends RpcEnv(conf) with Logging {
 
   private[netty] val transportConf = SparkTransportConf.fromSparkConf(
     conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"),
     "rpc",
     conf.getInt("spark.rpc.io.threads", 0))
 
-  private val dispatcher: Dispatcher = new Dispatcher(this)
+  private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)
 
   private val streamManager = new NettyStreamManager(this)
 
@@ -451,7 +452,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
       new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
     val nettyEnv =
       new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
-        config.securityManager)
+        config.securityManager, config.numUsableCores)
     if (!config.clientMode) {
       val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
         nettyEnv.startServer(config.bindAddress, actualPort)
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index 2b1bce4d208f688d3f2a5d0173b6ef36f664ee63..777163709bbf527398e328598f5da11afb791a5a 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -31,7 +31,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
       port: Int,
       clientMode: Boolean = false): RpcEnv = {
     val config = RpcEnvConfig(conf, "test", "localhost", "localhost", port,
-      new SecurityManager(conf), clientMode)
+      new SecurityManager(conf), 0, clientMode)
     new NettyRpcEnvFactory().create(config)
   }
 
@@ -47,7 +47,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar {
   test("advertise address different from bind address") {
     val sparkConf = new SparkConf()
     val config = RpcEnvConfig(sparkConf, "test", "localhost", "example.com", 0,
-      new SecurityManager(sparkConf), false)
+      new SecurityManager(sparkConf), 0, false)
     val env = new NettyRpcEnvFactory().create(config)
     try {
       assert(env.address.hostPort.startsWith("example.com:"))
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 6ff210adcbc55130655ea0db287d24e221d01efc..fc925022b2718d12f39f06cb3eacbd4b4056c211 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -459,8 +459,10 @@ private[spark] class ApplicationMaster(
   }
 
   private def runExecutorLauncher(securityMgr: SecurityManager): Unit = {
-    rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, -1, sparkConf, securityMgr,
-      clientMode = true)
+    val hostname = Utils.localHostName
+    val amCores = sparkConf.get(AM_CORES)
+    rpcEnv = RpcEnv.create("sparkYarnAM", hostname, hostname, -1, sparkConf, securityMgr,
+      amCores, true)
     val driverRef = waitForSparkDriver()
     addAmIpFilter()
     registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
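Note (reviewer sketch, not part of the patch): the heart of this change is the sizing rule in Dispatcher's threadpool initializer — use the process's allocated core count when known, fall back to the host's processor count when it is 0, and let an explicit spark.rpc.netty.dispatcher.numThreads setting win outright. A minimal standalone Scala sketch of that rule follows; the object and method names (DispatcherSizingSketch, dispatcherThreads) are hypothetical and exist only for illustration.

object DispatcherSizingSketch {
  // Mirrors the rule added to Dispatcher above: numUsableCores == 0 (the
  // default plumbed through the convenience RpcEnv.create overload) keeps
  // the old host-wide behaviour, while a positive value sizes the pool to
  // the cores actually allocated to this process.
  def dispatcherThreads(numUsableCores: Int, configuredNumThreads: Option[Int]): Int = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    configuredNumThreads.getOrElse(math.max(2, availableCores))
  }

  def main(args: Array[String]): Unit = {
    // An executor allocated 4 cores on a 64-core host now gets 4 threads, not 64.
    println(dispatcherThreads(numUsableCores = 4, configuredNumThreads = None))
    // Callers passing 0 keep the host-wide default, floored at 2.
    println(dispatcherThreads(numUsableCores = 0, configuredNumThreads = None))
    // spark.rpc.netty.dispatcher.numThreads, when set, still takes precedence.
    println(dispatcherThreads(numUsableCores = 4, configuredNumThreads = Some(8)))
  }
}

On the YARN side, the ApplicationMaster now passes its own allocation (AM_CORES, i.e. spark.yarn.am.cores) instead of the implicit 0, so the ExecutorLauncher's dispatcher pool no longer scales with the size of the NodeManager host it happens to land on.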