Commit d295ccb4 authored by Reynold Xin

Added a closureSerializer field in SparkEnv and use it to serialize tasks.
parent 968f75f6
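The change follows one pattern throughout Executor, LocalScheduler, and SimpleJob: obtain the new closureSerializer from the shared SparkEnv and create a fresh serializer instance at the point of use. A minimal sketch of that access pattern, assuming a SparkEnv has already been registered via SparkEnv.set(env) and with the task value standing in as a placeholder:

    // Sketch only; it mirrors the hunks below rather than adding new behavior.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val bytes = ser.serialize(task)                    // 'task' is a placeholder value
    val roundTripped = ser.deserialize[Task[_]](bytes) // the same instance can deserialize

Creating a new instance per use, rather than sharing one, is presumably done because serializer instances are not required to be thread-safe; the commit does not state this explicitly.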
@@ -57,16 +57,17 @@ class Executor extends org.apache.mesos.Executor with Logging {
   extends Runnable {
     override def run() = {
       val tid = desc.getTaskId.getValue
+      SparkEnv.set(env)
+      Thread.currentThread.setContextClassLoader(classLoader)
+      val ser = SparkEnv.get.closureSerializer.newInstance()
       logInfo("Running task ID " + tid)
       d.sendStatusUpdate(TaskStatus.newBuilder()
         .setTaskId(desc.getTaskId)
         .setState(TaskState.TASK_RUNNING)
         .build())
       try {
-        SparkEnv.set(env)
-        Thread.currentThread.setContextClassLoader(classLoader)
         Accumulators.clear
-        val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
+        val task = ser.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
         for (gen <- task.generation) {// Update generation if any is set
           env.mapOutputTracker.updateGeneration(gen)
         }
@@ -76,7 +77,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
         d.sendStatusUpdate(TaskStatus.newBuilder()
           .setTaskId(desc.getTaskId)
           .setState(TaskState.TASK_FINISHED)
-          .setData(ByteString.copyFrom(Utils.serialize(result)))
+          .setData(ByteString.copyFrom(ser.serialize(result)))
           .build())
         logInfo("Finished task ID " + tid)
       } catch {
@@ -85,7 +86,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
           d.sendStatusUpdate(TaskStatus.newBuilder()
             .setTaskId(desc.getTaskId)
             .setState(TaskState.TASK_FAILED)
-            .setData(ByteString.copyFrom(Utils.serialize(reason)))
+            .setData(ByteString.copyFrom(ser.serialize(reason)))
             .build())
         }
         case t: Throwable => {
@@ -93,7 +94,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
           d.sendStatusUpdate(TaskStatus.newBuilder()
             .setTaskId(desc.getTaskId)
             .setState(TaskState.TASK_FAILED)
-            .setData(ByteString.copyFrom(Utils.serialize(reason)))
+            .setData(ByteString.copyFrom(ser.serialize(reason)))
             .build())
           // TODO: Handle errors in tasks less dramatically
...
@@ -38,9 +38,13 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler
       // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
       // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
       Accumulators.clear
-      val bytes = Utils.serialize(task)
-      logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
-      val deserializedTask = Utils.deserialize[Task[_]](
+      val ser = SparkEnv.get.closureSerializer.newInstance()
+      val startTime = System.currentTimeMillis
+      val bytes = ser.serialize(task)
+      val timeTaken = System.currentTimeMillis - startTime
+      logInfo("Size of task %d is %d bytes and took %d ms to serialize by %s"
+        .format(idInJob, bytes.size, timeTaken, ser.getClass.getName))
+      val deserializedTask = ser.deserialize[Task[_]](
         bytes, Thread.currentThread.getContextClassLoader)
       val result: Any = deserializedTask.run(attemptId)
       val accumUpdates = Accumulators.values
...
@@ -30,6 +30,9 @@ class SimpleJob(
   // Maximum times a task is allowed to fail before failing the job
   val MAX_TASK_FAILURES = 4

+  // Serializer for closures and tasks.
+  val ser = SparkEnv.get.closureSerializer.newInstance()
+
   val callingThread = Thread.currentThread
   val tasks = tasksSeq.toArray
   val numTasks = tasks.length
@@ -170,8 +173,14 @@ class SimpleJob(
         .setType(Resource.Type.SCALAR)
         .setScalar(Resource.Scalar.newBuilder().setValue(CPUS_PER_TASK).build())
         .build()
-      val serializedTask = Utils.serialize(task)
-      logDebug("Serialized size: " + serializedTask.size)
+
+      val startTime = System.currentTimeMillis
+      val serializedTask = ser.serialize(task)
+      val timeTaken = System.currentTimeMillis - startTime
+
+      logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
+        .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
+
       val taskName = "task %d:%d".format(jobId, index)
       return Some(TaskDescription.newBuilder()
         .setTaskId(taskId)
@@ -208,7 +217,8 @@ class SimpleJob(
       tasksFinished += 1
       logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
       // Deserialize task result
-      val result = Utils.deserialize[TaskResult[_]](status.getData.toByteArray)
+      val result = ser.deserialize[TaskResult[_]](
+        status.getData.toByteArray)
       sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
       // Mark finished and stop if we've finished all the tasks
       finished(index) = true
@@ -230,7 +240,8 @@ class SimpleJob(
       // Check if the problem is a map output fetch failure. In that case, this
       // task will never succeed on any node, so tell the scheduler about it.
       if (status.getData != null && status.getData.size > 0) {
-        val reason = Utils.deserialize[TaskEndReason](status.getData.toByteArray)
+        val reason = ser.deserialize[TaskEndReason](
+          status.getData.toByteArray)
         reason match {
           case fetchFailed: FetchFailed =>
             logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
...
@@ -3,6 +3,7 @@ package spark
 class SparkEnv (
     val cache: Cache,
     val serializer: Serializer,
+    val closureSerializer: Serializer,
     val cacheTracker: CacheTracker,
     val mapOutputTracker: MapOutputTracker,
     val shuffleFetcher: ShuffleFetcher,
@@ -27,6 +28,11 @@ object SparkEnv {
     val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
     val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]

+    val closureSerializerClass =
+      System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
+    val closureSerializer =
+      Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
+
     val cacheTracker = new CacheTracker(isMaster, cache)
     val mapOutputTracker = new MapOutputTracker(isMaster)
@@ -38,6 +44,13 @@
     val shuffleMgr = new ShuffleManager()

-    new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr)
+    new SparkEnv(
+      cache,
+      serializer,
+      closureSerializer,
+      cacheTracker,
+      mapOutputTracker,
+      shuffleFetcher,
+      shuffleMgr)
   }
 }
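The SparkEnv hunk above makes the closure serializer configurable through the new spark.closure.serializer system property, defaulting to spark.JavaSerializer. Because the class is loaded with Class.forName(...).newInstance(), any replacement must implement spark.Serializer and have a no-argument constructor. A hypothetical configuration sketch (the class name is purely illustrative and does not exist in this codebase):

    // Hypothetical: point the property at an alternative Serializer implementation
    // before the SparkEnv is created; otherwise spark.JavaSerializer is used.
    System.setProperty("spark.closure.serializer", "my.pkg.CustomClosureSerializer")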
@@ -13,16 +13,27 @@ import scala.util.Random
  */
 object Utils {
-  // The serializer in this object is used by Spark to serialize closures.
-  val serializerClass = System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
-  val ser = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
-
-  def serialize[T](o: T): Array[Byte] = ser.newInstance().serialize[T](o)
-
-  def deserialize[T](bytes: Array[Byte]): T = ser.newInstance().deserialize[T](bytes)
-
+  def serialize[T](o: T): Array[Byte] = {
+    val bos = new ByteArrayOutputStream()
+    val oos = new ObjectOutputStream(bos)
+    oos.writeObject(o)
+    oos.close
+    return bos.toByteArray
+  }
+
+  def deserialize[T](bytes: Array[Byte]): T = {
+    val bis = new ByteArrayInputStream(bytes)
+    val ois = new ObjectInputStream(bis)
+    return ois.readObject.asInstanceOf[T]
+  }
+
   def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
-    ser.newInstance().deserialize[T](bytes, loader)
+    val bis = new ByteArrayInputStream(bytes)
+    val ois = new ObjectInputStream(bis) {
+      override def resolveClass(desc: ObjectStreamClass) =
+        Class.forName(desc.getName, false, loader)
+    }
+    return ois.readObject.asInstanceOf[T]
   }

   def isAlpha(c: Char): Boolean = {
...
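With the closure serializer now living in SparkEnv, Utils falls back to plain Java serialization. The loader-taking variant overrides resolveClass so that classes visible only through the supplied ClassLoader (for example, classes an executor loaded for a job) can still be resolved. A small, hypothetical round trip against the Utils methods shown above, with an arbitrary serializable payload:

    // Hypothetical usage sketch; the payload is illustrative.
    val payload = Map("a" -> 1, "b" -> 2)
    val bytes = Utils.serialize(payload)
    // Resolve classes through an explicit loader, e.g. the thread's context class loader.
    val restored = Utils.deserialize[Map[String, Int]](bytes, Thread.currentThread.getContextClassLoader)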