Commit 336cd341 authored by Reynold Xin

Small refactoring to pass SparkEnv into Executor rather than creating SparkEnv in Executor.

This consolidates some code paths and makes constructor arguments simpler for a few classes.

Author: Reynold Xin <rxin@databricks.com>

Closes #3738 from rxin/sparkEnvDepRefactor and squashes the following commits:

82e02cc [Reynold Xin] Fixed a couple of bugs.
217062a [Reynold Xin] Code review feedback.
bd00af7 [Reynold Xin] Small refactoring to pass SparkEnv into Executor rather than creating SparkEnv in Executor.
parent 8e253ebb
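The gist of the refactoring: Executor previously rebuilt a SparkConf from a property list and constructed its own SparkEnv; now each backend creates the SparkEnv up front and injects it. A simplified before/after of the Executor constructor, condensed from the diff below:

    // Before: raw properties in, SparkEnv created inside Executor.
    class Executor(
        executorId: String,
        slaveHostname: String,
        properties: Seq[(String, String)],
        numCores: Int,
        isLocal: Boolean = false,
        actorSystem: ActorSystem = null)

    // After: the caller builds SparkEnv and passes it in.
    class Executor(
        executorId: String,
        slaveHostname: String,
        env: SparkEnv,
        isLocal: Boolean = false)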
@@ -177,18 +177,18 @@ object SparkEnv extends Logging {
       hostname: String,
       port: Int,
       numCores: Int,
-      isLocal: Boolean,
-      actorSystem: ActorSystem = null): SparkEnv = {
-    create(
+      isLocal: Boolean): SparkEnv = {
+    val env = create(
       conf,
       executorId,
       hostname,
       port,
       isDriver = false,
       isLocal = isLocal,
-      defaultActorSystem = actorSystem,
       numUsableCores = numCores
     )
+    SparkEnv.set(env)
+    env
   }
 
   /**
@@ -202,7 +202,6 @@ object SparkEnv extends Logging {
       isDriver: Boolean,
       isLocal: Boolean,
       listenerBus: LiveListenerBus = null,
-      defaultActorSystem: ActorSystem = null,
       numUsableCores: Int = 0): SparkEnv = {
 
     // Listener bus is only used on the driver
@@ -212,20 +211,17 @@ object SparkEnv extends Logging {
     val securityManager = new SecurityManager(conf)
 
-    // If an existing actor system is already provided, use it.
-    // This is the case when an executor is launched in coarse-grained mode.
-    val (actorSystem, boundPort) =
-      Option(defaultActorSystem) match {
-        case Some(as) => (as, port)
-        case None =>
-          val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
-          AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
-      }
+    // Create the ActorSystem for Akka and get the port it binds to.
+    val (actorSystem, boundPort) = {
+      val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+      AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
+    }
 
     // Figure out which port Akka actually bound to in case the original port is 0 or occupied.
     // This is so that we tell the executors the correct port to connect to.
     if (isDriver) {
       conf.set("spark.driver.port", boundPort.toString)
     } else {
       conf.set("spark.executor.port", boundPort.toString)
     }
 
     // Create an instance of the class with the given name, possibly initializing it with our conf
......
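Note the contract of SparkEnv.createExecutorEnv after this hunk: it both registers the environment globally via SparkEnv.set(env) and returns it, so callers no longer set it themselves. A hypothetical call site, using the argument names from the signature above:

    // Build the executor-side SparkEnv once; the returned env is also
    // registered globally, so SparkEnv.get sees the same instance.
    val env = SparkEnv.createExecutorEnv(
      conf, executorId, hostname, port, numCores, isLocal = false)
    assert(SparkEnv.get eq env)  // createExecutorEnv now calls SparkEnv.set(env)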
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
 import scala.concurrent.Await
 
-import akka.actor.{Actor, ActorSelection, ActorSystem, Props}
+import akka.actor.{Actor, ActorSelection, Props}
 import akka.pattern.Patterns
 import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
@@ -38,8 +38,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     executorId: String,
     hostPort: String,
     cores: Int,
-    sparkProperties: Seq[(String, String)],
-    actorSystem: ActorSystem)
+    env: SparkEnv)
   extends Actor with ActorLogReceive with ExecutorBackend with Logging {
 
   Utils.checkHostPort(hostPort, "Expected hostport")
@@ -58,8 +57,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     case RegisteredExecutor =>
       logInfo("Successfully registered with driver")
       val (hostname, _) = Utils.parseHostPort(hostPort)
-      executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false,
-        actorSystem)
+      executor = new Executor(executorId, hostname, env, isLocal = false)
 
     case RegisterExecutorFailed(message) =>
       logError("Slave registration failed: " + message)
@@ -70,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend(
         logError("Received LaunchTask command but executor was null")
         System.exit(1)
       } else {
-        val ser = SparkEnv.get.closureSerializer.newInstance()
+        val ser = env.closureSerializer.newInstance()
         val taskDesc = ser.deserialize[TaskDescription](data.value)
         logInfo("Got assigned task " + taskDesc.taskId)
         executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
@@ -128,21 +126,25 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
         Seq[(String, String)](("spark.app.id", appId))
       fetcher.shutdown()
 
-      // Create a new ActorSystem using driver's Spark properties to run the backend.
+      // Create SparkEnv using properties we fetched from the driver.
       val driverConf = new SparkConf().setAll(props)
-      val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
-        SparkEnv.executorActorSystemName,
-        hostname, port, driverConf, new SecurityManager(driverConf))
-      // set it
+      val env = SparkEnv.createExecutorEnv(
+        driverConf, executorId, hostname, port, cores, isLocal = false)
+
+      // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore.
+      val boundPort = env.conf.getInt("spark.executor.port", 0)
+      assert(boundPort != 0)
+
+      // Start the CoarseGrainedExecutorBackend actor.
       val sparkHostPort = hostname + ":" + boundPort
-      actorSystem.actorOf(
+      env.actorSystem.actorOf(
         Props(classOf[CoarseGrainedExecutorBackend],
-          driverUrl, executorId, sparkHostPort, cores, props, actorSystem),
+          driverUrl, executorId, sparkHostPort, cores, env),
         name = "Executor")
       workerUrl.foreach { url =>
-        actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+        env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
       }
-      actorSystem.awaitTermination()
+      env.actorSystem.awaitTermination()
     }
   }
......
@@ -26,7 +26,7 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.control.NonFatal
 
-import akka.actor.{Props, ActorSystem}
+import akka.actor.Props
 
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -42,10 +42,8 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
 private[spark] class Executor(
     executorId: String,
     slaveHostname: String,
-    properties: Seq[(String, String)],
-    numCores: Int,
-    isLocal: Boolean = false,
-    actorSystem: ActorSystem = null)
+    env: SparkEnv,
+    isLocal: Boolean = false)
   extends Logging
 {
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -55,6 +53,8 @@ private[spark] class Executor(
   private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
 
+  private val conf = env.conf
+
   @volatile private var isStopped = false
 
   // No ip or host:port - just hostname
@@ -65,10 +65,6 @@ private[spark] class Executor(
   // Make sure the local hostname we report matches the cluster scheduler's name for this host
   Utils.setCustomHostname(slaveHostname)
 
-  // Set spark.* properties from executor arg
-  val conf = new SparkConf(true)
-  conf.setAll(properties)
-
   if (!isLocal) {
     // Setup an uncaught exception handler for non-local mode.
     // Make any thread terminations due to uncaught exceptions kill the entire
@@ -77,21 +73,11 @@ private[spark] class Executor(
   }
 
   val executorSource = new ExecutorSource(this, executorId)
 
-  // Initialize Spark environment (using system properties read above)
   conf.set("spark.executor.id", executorId)
-  private val env = {
-    if (!isLocal) {
-      val port = conf.getInt("spark.executor.port", 0)
-      val _env = SparkEnv.createExecutorEnv(
-        conf, executorId, slaveHostname, port, numCores, isLocal, actorSystem)
-      SparkEnv.set(_env)
-      _env.metricsSystem.registerSource(executorSource)
-      _env.blockManager.initialize(conf.getAppId)
-      _env
-    } else {
-      SparkEnv.get
-    }
+
+  if (!isLocal) {
+    env.metricsSystem.registerSource(executorSource)
+    env.blockManager.initialize(conf.getAppId)
   }
 
   // Create an actor for receiving RPCs from the driver
@@ -167,7 +153,7 @@ private[spark] class Executor(
     override def run() {
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
-      val ser = SparkEnv.get.closureSerializer.newInstance()
+      val ser = env.closureSerializer.newInstance()
       logInfo(s"Running $taskName (TID $taskId)")
       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
       var taskStart: Long = 0
@@ -202,7 +188,7 @@ private[spark] class Executor(
           throw new TaskKilledException
         }
 
-        val resultSer = SparkEnv.get.serializer.newInstance()
+        val resultSer = env.serializer.newInstance()
         val beforeSerialization = System.currentTimeMillis()
         val valueBytes = resultSer.serialize(value)
         val afterSerialization = System.currentTimeMillis()
......
@@ -25,7 +25,7 @@ import org.apache.mesos.protobuf.ByteString
 import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary}
 import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
 
-import org.apache.spark.{Logging, TaskState}
+import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.{SignalLogger, Utils}
@@ -64,11 +64,15 @@ private[spark] class MesosExecutorBackend
     this.driver = driver
     val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
       Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
+    val conf = new SparkConf(loadDefaults = true).setAll(properties)
+    val port = conf.getInt("spark.executor.port", 0)
+    val env = SparkEnv.createExecutorEnv(
+      conf, executorId, slaveInfo.getHostname, port, cpusPerTask, isLocal = false)
+
     executor = new Executor(
       executorId,
       slaveInfo.getHostname,
-      properties,
-      cpusPerTask)
+      env)
   }
 
   override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
......
@@ -41,17 +41,18 @@ private case class StopExecutor()
  * and the TaskSchedulerImpl.
  */
 private[spark] class LocalActor(
-  scheduler: TaskSchedulerImpl,
-  executorBackend: LocalBackend,
-  private val totalCores: Int) extends Actor with ActorLogReceive with Logging {
+    scheduler: TaskSchedulerImpl,
+    executorBackend: LocalBackend,
+    private val totalCores: Int)
+  extends Actor with ActorLogReceive with Logging {
 
   private var freeCores = totalCores
 
   private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
   private val localExecutorHostname = "localhost"
 
-  val executor = new Executor(
-    localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true)
+  private val executor = new Executor(
+    localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
 
   override def receiveWithLogging = {
     case ReviveOffers =>
......
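Taken together, every non-local backend now follows the same wiring, sketched here by condensing the hunks above (identifiers as they appear in the diff):

    // Shared executor bootstrap path after this commit (condensed sketch):
    val driverConf = new SparkConf().setAll(props)  // properties fetched from the driver
    val env = SparkEnv.createExecutorEnv(
      driverConf, executorId, hostname, port, cores, isLocal = false)
    val executor = new Executor(executorId, hostname, env, isLocal = false)
    // Local mode instead reuses the driver's environment via SparkEnv.get.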