diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ef598ae41b1f69867042e0302c24e27fb7036c15..673f9a810dd7348eca15501287d1807892db6d45 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -33,8 +33,9 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} -import spark.serializer.SerializerInstance +import spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import spark.deploy.SparkHadoopUtil +import java.nio.ByteBuffer /** @@ -68,6 +69,47 @@ private object Utils extends Logging { return ois.readObject.asInstanceOf[T] } + /** Serialize via nested stream using specific serializer */ + def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = { + val osWrapper = ser.serializeStream(new OutputStream { + def write(b: Int) = os.write(b) + + override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len) + }) + try { + f(osWrapper) + } finally { + osWrapper.close() + } + } + + /** Deserialize via nested stream using specific serializer */ + def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = { + val isWrapper = ser.deserializeStream(new InputStream { + def read(): Int = is.read() + + override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len) + }) + try { + f(isWrapper) + } finally { + isWrapper.close() + } + } + + /** + * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. + */ + def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { + if (bb.hasArray) { + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } else { + val bbval = new Array[Byte](bb.remaining()) + bb.get(bbval) + out.write(bbval) + } + } + def isAlpha(c: Char): Boolean = { (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') } diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala index 16ba0c26f8ee10414d5843b7ecbec520765ec0c0..33079cd53937d99062c2fd6ddc0e04a68476a007 100644 --- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala @@ -20,13 +20,15 @@ package spark.rdd import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer import scala.collection.Map -import spark.{RDD, TaskContext, SparkContext, Partition} +import spark._ +import java.io._ +import scala.Serializable private[spark] class ParallelCollectionPartition[T: ClassManifest]( - val rddId: Long, - val slice: Int, - values: Seq[T]) - extends Partition with Serializable { + var rddId: Long, + var slice: Int, + var values: Seq[T]) + extends Partition with Serializable { def iterator: Iterator[T] = values.iterator @@ -37,15 +39,49 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest]( case _ => false } - override val index: Int = slice + override def index: Int = slice + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + + // Treat java serializer with default action rather than going thru serialization, to avoid a + // separate serialization header. + + sfactory match { + case js: JavaSerializer => out.defaultWriteObject() + case _ => + out.writeLong(rddId) + out.writeInt(slice) + + val ser = sfactory.newInstance() + Utils.serializeViaNestedStream(out, ser)(_.writeObject(values)) + } + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = { + + val sfactory = SparkEnv.get.serializer + sfactory match { + case js: JavaSerializer => in.defaultReadObject() + case _ => + rddId = in.readLong() + slice = in.readInt() + + val ser = sfactory.newInstance() + Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject()) + } + } } private[spark] class ParallelCollectionRDD[T: ClassManifest]( @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, - locationPrefs: Map[Int,Seq[String]]) - extends RDD[T](sc, Nil) { + locationPrefs: Map[Int, Seq[String]]) + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. @@ -82,16 +118,17 @@ private object ParallelCollectionRDD { 1 } slice(new Range( - r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) + r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) } case r: Range => { (0 until numSlices).map(i => { val start = ((i * r.length.toLong) / numSlices).toInt - val end = (((i+1) * r.length.toLong) / numSlices).toInt + val end = (((i + 1) * r.length.toLong) / numSlices).toInt new Range(r.start + start * r.step, r.start + end * r.step, r.step) }).asInstanceOf[Seq[Seq[T]]] } - case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc + case nr: NumericRange[_] => { + // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything var r = nr @@ -102,10 +139,10 @@ private object ParallelCollectionRDD { slices } case _ => { - val array = seq.toArray // To prevent O(n^2) operations for List etc + val array = seq.toArray // To prevent O(n^2) operations for List etc (0 until numSlices).map(i => { val start = ((i * array.length.toLong) / numSlices).toInt - val end = (((i+1) * array.length.toLong) / numSlices).toInt + val end = (((i + 1) * array.length.toLong) / numSlices).toInt array.slice(start, end).toSeq }) } diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala index dc0621ea7ba5fb20ec6be60627d20d602790a1de..89793e0e8287839f62300488c0fe7ccdbcc25885 100644 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/spark/scheduler/TaskResult.scala @@ -21,6 +21,8 @@ import java.io._ import scala.collection.mutable.Map import spark.executor.TaskMetrics +import spark.{Utils, SparkEnv} +import java.nio.ByteBuffer // Task result. Also contains updates to accumulator variables. // TODO: Use of distributed cache to return result is a hack to get around @@ -30,7 +32,13 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: def this() = this(null.asInstanceOf[T], null, null) override def writeExternal(out: ObjectOutput) { - out.writeObject(value) + + val objectSer = SparkEnv.get.serializer.newInstance() + val bb = objectSer.serialize(value) + + out.writeInt(bb.remaining()) + Utils.writeByteBuffer(bb, out) + out.writeInt(accumUpdates.size) for ((key, value) <- accumUpdates) { out.writeLong(key) @@ -40,7 +48,14 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: } override def readExternal(in: ObjectInput) { - value = in.readObject().asInstanceOf[T] + + val objectSer = SparkEnv.get.serializer.newInstance() + + val blen = in.readInt() + val byteVal = new Array[Byte](blen) + in.readFully(byteVal) + value = objectSer.deserialize(ByteBuffer.wrap(byteVal)) + val numUpdates = in.readInt if (numUpdates == 0) { accumUpdates = null diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala index d2110bd098e59b3963ff0819235f15c1ab5d0d3e..7f855cd345b7f6ab4ea600d0c9aed39f86574f87 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -92,7 +92,8 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble // Serializer for closures and tasks. - val ser = SparkEnv.get.closureSerializer.newInstance() + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks val numTasks = tasks.length @@ -534,6 +535,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: } override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + SparkEnv.set(env) state match { case TaskState.FINISHED => taskFinished(tid, state, serializedData) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index bb0c836e86b2a39c6c0f39082806b77ce3265971..f274b1a767984524df7b0fba6d92f9045c6ce312 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -169,7 +169,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: // Set the Spark execution environment for the worker thread SparkEnv.set(env) val ser = SparkEnv.get.closureSerializer.newInstance() - var attemptedTask: Option[Task[_]] = None + val objectSer = SparkEnv.get.serializer.newInstance() + var attemptedTask: Option[Task[_]] = None val start = System.currentTimeMillis() var taskStart: Long = 0 try { @@ -193,9 +194,9 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: // executor does. This is useful to catch serialization errors early // on in development (so when users move their local Spark programs // to the cluster, they don't get surprised by serialization errors). - val serResult = ser.serialize(result) + val serResult = objectSer.serialize(result) deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = ser.deserialize[Any](serResult) + val resultToReturn = objectSer.deserialize[Any](serResult) val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) val serviceTime = System.currentTimeMillis() - taskStart diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala index 4ab15532cf8ac03ed2c5dc386fc5ad39498e0181..c38eeb9e11eb96b23fad0efd906cfc927ce179af 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -42,7 +42,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas val taskInfos = new HashMap[Long, TaskInfo] val numTasks = taskSet.tasks.size var numFinished = 0 - val ser = SparkEnv.get.closureSerializer.newInstance() + val env = SparkEnv.get + val ser = env.closureSerializer.newInstance() val copiesRunning = new Array[Int](numTasks) val finished = new Array[Boolean](numTasks) val numFailures = new Array[Int](numTasks) @@ -143,6 +144,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + SparkEnv.set(env) state match { case TaskState.FINISHED => taskEnded(tid, state, serializedData) diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala index 30d2d5282bc84bf887d25c5233020be9e7d06499..01390027c8de1bb5cfd5f1d8f8cc3f4a4c2a7d55 100644 --- a/core/src/test/scala/spark/KryoSerializerSuite.scala +++ b/core/src/test/scala/spark/KryoSerializerSuite.scala @@ -22,7 +22,9 @@ import scala.collection.mutable import org.scalatest.FunSuite import com.esotericsoftware.kryo._ -class KryoSerializerSuite extends FunSuite { +import KryoTest._ + +class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("basic types") { val ser = (new KryoSerializer).newInstance() def check[T](t: T) { @@ -124,6 +126,45 @@ class KryoSerializerSuite extends FunSuite { System.clearProperty("spark.kryo.registrator") } + + test("kryo with collect") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x) + assert(control === result.toSeq) + } + + test("kryo with parallelize") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect() + assert (control === result.toSeq) + } + + test("kryo with reduce") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(control.sum === result) + } + + // TODO: this still doesn't work + ignore("kryo with fold") { + val control = 1 :: 2 :: Nil + val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)) + .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x + assert(10 + control.sum === result) + } + + override def beforeAll() { + System.setProperty("spark.serializer", "spark.KryoSerializer") + System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) + super.beforeAll() + } + + override def afterAll() { + super.afterAll() + System.clearProperty("spark.kryo.registrator") + System.clearProperty("spark.serializer") + } } object KryoTest {