Commit 12dc385a authored by jerryshao

Add Spark multi-user support for standalone mode and Mesos

parent aadeda5e
core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -145,6 +145,14 @@ class SparkContext(
     executorEnvs ++= environment
   }
 
+  // Set SPARK_USER for user who is running SparkContext.
+  val sparkUser = Option {
+    Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+  }.getOrElse {
+    SparkContext.SPARK_UNKNOWN_USER
+  }
+  executorEnvs("SPARK_USER") = sparkUser
+
   // Create and start the scheduler
   private[spark] var taskScheduler: TaskScheduler = {
     // Regular expression used for local[N] master format
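The driver resolves the submitting user once and publishes it to executors through executorEnvs, so standalone and Mesos executors see SPARK_USER in their environment. A minimal standalone sketch (not part of the patch) of the same resolution order:

    // Resolution order used above: the JVM's "user.name" property first,
    // then the SPARK_USER environment variable, then the "<unknown>" sentinel.
    object ResolveUser {
      val SPARK_UNKNOWN_USER = "<unknown>"

      def main(args: Array[String]) {
        val sparkUser = Option {
          Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
        }.getOrElse(SPARK_UNKNOWN_USER)
        println("Running as: " + sparkUser)
      }
    }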
@@ -981,6 +989,8 @@ object SparkContext {
 
   private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
 
+  private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
+
   implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
     def addInPlace(t1: Double, t2: Double): Double = t1 + t2
     def zero(initialValue: Double) = 0.0
core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.deploy
 
+import java.security.PrivilegedExceptionAction
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
 
 import org.apache.spark.SparkException
 
@@ -27,6 +30,15 @@ import org.apache.spark.SparkException
  */
 private[spark]
 class SparkHadoopUtil {
+  val conf = newConfiguration()
+  UserGroupInformation.setConfiguration(conf)
+
+  def runAsUser(user: String)(func: () => Unit) {
+    val ugi = UserGroupInformation.createRemoteUser(user)
+    ugi.doAs(new PrivilegedExceptionAction[Unit] {
+      def run: Unit = func()
+    })
+  }
 
   /**
    * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
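runAsUser executes the supplied thunk inside UserGroupInformation.doAs, so Hadoop attributes any filesystem access made in the block to the named remote (i.e. not Kerberos-authenticated) user. A usage sketch, with "alice" as a placeholder user name rather than anything from the patch:

    // Inside doAs, getCurrentUser reflects the remote user we created.
    SparkHadoopUtil.get.runAsUser("alice") { () =>
      val ugi = UserGroupInformation.getCurrentUser()
      println("Effective Hadoop user: " + ugi.getShortUserName())
    }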
@@ -42,9 +54,9 @@ class SparkHadoopUtil {
 
   def isYarnMode(): Boolean = { false }
 }
 
 object SparkHadoopUtil {
-  private val hadoop = { 
+  private val hadoop = {
     val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
-    if (yarnMode) { 
+    if (yarnMode) {
       try {
@@ -56,7 +68,7 @@ object SparkHadoopUtil {
       new SparkHadoopUtil
     }
   }
-  
+
   def get: SparkHadoopUtil = {
     hadoop
   }
core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,8 +25,9 @@ import java.util.concurrent._
 import scala.collection.JavaConversions._
 import scala.collection.mutable.HashMap
 
-import org.apache.spark.scheduler._
 import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util.Utils
 
@@ -129,6 +130,8 @@ private[spark] class Executor(
   // Maintains the list of running tasks.
   private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
 
+  val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+
   def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
     val tr = new TaskRunner(context, taskId, serializedTask)
     runningTasks.put(taskId, tr)
@@ -176,7 +179,7 @@ private[spark] class Executor(
       }
     }
 
-    override def run() {
+    override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
       val startTime = System.currentTimeMillis()
       SparkEnv.set(env)
       Thread.currentThread.setContextClassLoader(replClassLoader)
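With this change the entire body of TaskRunner.run becomes the thunk passed to runAsUser, so every Hadoop call a task makes is attributed to the SPARK_USER picked up above. Schematically (an illustration of the desugaring, with the original body elided):

    // Roughly what the rewritten run() expands to:
    override def run(): Unit = {
      val ugi = UserGroupInformation.createRemoteUser(sparkUser)
      ugi.doAs(new PrivilegedExceptionAction[Unit] {
        def run: Unit = {
          // ... original task deserialization and execution body ...
        }
      })
    }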
core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -91,8 +91,10 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
 
     sc.addSparkListener(joblogger)
     val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
     rdd.reduceByKey(_+_).collect()
+
+    val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
-    joblogger.getLogDir should be ("/tmp/spark")
+    joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
     joblogger.getJobIDtoPrintWriter.size should be (1)
     joblogger.getStageIDToJobID.size should be (2)
     joblogger.getStageIDToJobID.get(0) should be (Some(0))
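The JobLogger's default log directory is now suffixed with the resolved user name, so concurrent users on one machine no longer write into the same /tmp/spark directory; the updated assertion derives the expected path the same way. For illustration, assuming the property lookup shown above:

    val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
    val expectedDir = "/tmp/spark-%s".format(user)  // e.g. "/tmp/spark-jerryshao"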