diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index fa9df3a97e72e66cb735afc864582f38300e39ec..0be4b4feb8c360e9812c0d7d7db118f713ab8dcf 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -85,17 +85,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey(_), true) + self.mapPartitions(aggregator.combineValuesByKey, true) } else if (mapSideCombine) { - val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) - val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) + val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey, true) + val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner) + .setSerializer(serializerClass) + partitioned.mapPartitions(aggregator.combineCombinersByKey, true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) - val values = new ShuffledRDD[K, V](self, partitioner, serializerClass) - values.mapPartitions(aggregator.combineValuesByKey(_), true) + val values = new ShuffledRDD[K, V](self, partitioner).setSerializer(serializerClass) + values.mapPartitions(aggregator.combineValuesByKey, true) } } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 019b12d2d5e43796c52ff6bc1ff165642abbe9fd..c2d95dc06049f8b603abbf9373422dfc4739c882 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -60,12 +60,16 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output. */ -class CoGroupedRDD[K]( - @transient var rdds: Seq[RDD[(K, _)]], - part: Partitioner, - val serializerClass: String = null) +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + private var serializerClass: String = null + + def setSerializer(cls: String): CoGroupedRDD[K] = { + serializerClass = cls + this + } + override def getDependencies: Seq[Dependency[_]] = { rdds.map { rdd: RDD[(K, _)] => if (rdd.partitioner == Some(part)) { diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 0137f80953e02a14e537f34539b4f6710682c4f8..bcf7d0d89c4451a1cb15457316a1fdb91fa8a788 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -17,8 +17,9 @@ package spark.rdd -import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} -import spark.SparkContext._ +import spark._ +import scala.Some +import scala.Some private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -30,15 +31,24 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * The resulting RDD from a shuffle (e.g. repartitioning of data). * @param prev the parent RDD. * @param part the partitioner used to partition the RDD - * @param serializerClass class name of the serializer to use. * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( - @transient prev: RDD[(K, V)], - part: Partitioner, - serializerClass: String = null) - extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) { + @transient var prev: RDD[(K, V)], + part: Partitioner) + extends RDD[(K, V)](prev.context, Nil) { + + private var serializerClass: String = null + + def setSerializer(cls: String): ShuffledRDD[K, V] = { + serializerClass = cls + this + } + + override def getDependencies: Seq[Dependency[_]] = { + List(new ShuffleDependency(prev, part, serializerClass)) + } override val partitioner = Some(part) @@ -51,4 +61,9 @@ class ShuffledRDD[K, V]( SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics, SparkEnv.get.serializerManager.get(serializerClass)) } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } } diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala index 0402b9f2501d62e2fa9e49aa096ee193926519be..46b8cafaac3396f2dcf39f87e921b3de19001ca2 100644 --- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala @@ -49,10 +49,16 @@ import spark.OneToOneDependency private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( @transient var rdd1: RDD[(K, V)], @transient var rdd2: RDD[(K, W)], - part: Partitioner, - val serializerClass: String = null) + part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { + private var serializerClass: String = null + + def setSerializer(cls: String): SubtractedRDD[K, V, W] = { + serializerClass = cls + this + } + override def getDependencies: Seq[Dependency[_]] = { Seq(rdd1, rdd2).map { rdd => if (rdd.partitioner == Some(part)) { diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 752e4b85e67d2585f128a9e76b754aa49cafee1a..c686b8cc5af43a04dddad6d9256b9339b947d0e8 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -17,17 +17,8 @@ package spark -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet - import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ - -import com.google.common.io.Files import spark.rdd.ShuffledRDD import spark.SparkContext._ @@ -59,8 +50,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { } // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS), - classOf[spark.KryoSerializer].getName) + val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS)) + .setSerializer(classOf[spark.KryoSerializer].getName) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId assert(c.count === 10) @@ -81,7 +72,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { } // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. - val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName) + val c = new ShuffledRDD(b, new HashPartitioner(3)) + .setSerializer(classOf[spark.KryoSerializer].getName) assert(c.count === 10) } @@ -96,7 +88,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. // So, use Kryo - val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName) + val c = new ShuffledRDD(b, new HashPartitioner(10)) + .setSerializer(classOf[spark.KryoSerializer].getName) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId assert(c.count === 4)