Skip to content
Snippets Groups Projects
Commit 2c00ea3e authored by Reynold Xin
Browse files

Moved shuffle serializer setting from a constructor parameter to a...

Moved shuffle serializer setting from a constructor parameter to a setSerializer method in various RDDs that involve shuffle operations.
parent 0e84fee7
No related branches found
No related tags found
No related merge requests found
......@@ -85,17 +85,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
self.mapPartitions(aggregator.combineValuesByKey, true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey, true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
.setSerializer(serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey, true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
val values = new ShuffledRDD[K, V](self, partitioner).setSerializer(serializerClass)
values.mapPartitions(aggregator.combineValuesByKey, true)
}
}
......
......@@ -60,12 +60,16 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
*/
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
val serializerClass: String = null)
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private var serializerClass: String = null
/**
 * Stores the given serializer class name on this RDD and hands back the same
 * instance, so the call can be chained onto the constructor expression.
 *
 * @param cls class name of the serializer to record
 * @return this CoGroupedRDD
 */
def setSerializer(cls: String): CoGroupedRDD[K] = {
  this.serializerClass = cls
  this
}
override def getDependencies: Seq[Dependency[_]] = {
rdds.map { rdd: RDD[(K, _)] =>
if (rdd.partitioner == Some(part)) {
......
......@@ -17,8 +17,9 @@
package spark.rdd
import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
import spark._
import scala.Some
import scala.Some
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
......@@ -30,15 +31,24 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
* @param serializerClass class name of the serializer to use.
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
@transient prev: RDD[(K, V)],
part: Partitioner,
serializerClass: String = null)
extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part, serializerClass))) {
@transient var prev: RDD[(K, V)],
part: Partitioner)
extends RDD[(K, V)](prev.context, Nil) {
private var serializerClass: String = null
/**
 * Stores the serializer class name that this RDD passes to its ShuffleDependency
 * and to the serializer manager when fetching shuffle output, then returns this
 * RDD so the call can be chained onto the constructor expression.
 *
 * NOTE(review): presumably this should be called before the dependencies are
 * first requested, so the dependency and the fetch use the same serializer;
 * confirm against RDD's dependency caching.
 *
 * @param cls class name of the serializer to record
 * @return this ShuffledRDD
 */
def setSerializer(cls: String): ShuffledRDD[K, V] = {
  this.serializerClass = cls
  this
}
/**
 * Builds the dependency list for this RDD: a single ShuffleDependency on `prev`,
 * created with the partitioner `part` and the currently stored serializer class.
 */
override def getDependencies: Seq[Dependency[_]] = {
  val shuffleDep = new ShuffleDependency(prev, part, serializerClass)
  List(shuffleDep)
}
override val partitioner = Some(part)
......@@ -51,4 +61,9 @@ class ShuffledRDD[K, V](
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics,
SparkEnv.get.serializerManager.get(serializerClass))
}
/**
 * Runs the superclass cleanup and then drops this RDD's reference to `prev`.
 */
override def clearDependencies(): Unit = {
  super.clearDependencies()
  // Null out the parent reference after the superclass has done its own cleanup.
  this.prev = null
}
}
......@@ -49,10 +49,16 @@ import spark.OneToOneDependency
private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
@transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, W)],
part: Partitioner,
val serializerClass: String = null)
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
private var serializerClass: String = null
/**
 * Stores the given serializer class name on this RDD and hands back the same
 * instance, so the call can be chained onto the constructor expression.
 *
 * @param cls class name of the serializer to record
 * @return this SubtractedRDD
 */
def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
  this.serializerClass = cls
  this
}
override def getDependencies: Seq[Dependency[_]] = {
Seq(rdd1, rdd2).map { rdd =>
if (rdd.partitioner == Some(part)) {
......
......@@ -17,17 +17,8 @@
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
......@@ -59,8 +50,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non-serializable class.
val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
classOf[spark.KryoSerializer].getName)
val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS))
.setSerializer(classOf[spark.KryoSerializer].getName)
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 10)
......@@ -81,7 +72,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non-serializable class.
val c = new ShuffledRDD(b, new HashPartitioner(3), classOf[spark.KryoSerializer].getName)
val c = new ShuffledRDD(b, new HashPartitioner(3))
.setSerializer(classOf[spark.KryoSerializer].getName)
assert(c.count === 10)
}
......@@ -96,7 +88,8 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
val c = new ShuffledRDD(b, new HashPartitioner(10))
.setSerializer(classOf[spark.KryoSerializer].getName)
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 4)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment