diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index e7f75481939a8aa2cec36543f3b6e54d11b5d2d5..ec99648a8488a6440fa2fbe72bf3ec7fcc17be8b 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark
 
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.CollectionsUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{CollectionsUtils, Utils}
 
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -96,15 +98,15 @@ class HashPartitioner(partitions: Int) extends Partitioner {
  * the value of `partitions`.
  */
 class RangePartitioner[K : Ordering : ClassTag, V](
-    partitions: Int,
+    @transient partitions: Int,
     @transient rdd: RDD[_ <: Product2[K,V]],
-    private val ascending: Boolean = true)
+    private var ascending: Boolean = true)
   extends Partitioner {
 
-  private val ordering = implicitly[Ordering[K]]
+  private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
-  private val rangeBounds: Array[K] = {
+  private var rangeBounds: Array[K] = {
     if (partitions == 1) {
       Array()
     } else {
@@ -127,7 +129,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
 
   def numPartitions = rangeBounds.length + 1
 
-  private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
 
   def getPartition(key: Any): Int = {
     val k = key.asInstanceOf[K]
@@ -173,4 +175,40 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     result = prime * result + ascending.hashCode
     result
   }
+
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => out.defaultWriteObject()
+      case _ =>
+        out.writeBoolean(ascending)
+        out.writeObject(ordering)
+        out.writeObject(binarySearch)
+
+        val ser = sfactory.newInstance()
+        Utils.serializeViaNestedStream(out, ser) { stream =>
+          stream.writeObject(scala.reflect.classTag[Array[K]])
+          stream.writeObject(rangeBounds)
+        }
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => in.defaultReadObject()
+      case _ =>
+        ascending = in.readBoolean()
+        ordering = in.readObject().asInstanceOf[Ordering[K]]
+        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
+
+        val ser = sfactory.newInstance()
+        Utils.deserializeViaNestedStream(in, ser) { ds =>
+          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
+          rangeBounds = ds.readObject[Array[K]]()
+        }
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index b40fee7e9ab236889cee630d61d6e9e2b34355ab..c4f2f7e34f4d5905890704debf6303d868b75e43 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -206,6 +206,42 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // substracted rdd return results as Tuple2
     results(0) should be ((3, 33))
   }
+
+  test("sort with Java non serializable class - Kryo") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // If the Kryo serializer is not used correctly, the shuffle would fail because the
+    // default Java serializer cannot handle the non serializable class.
+    val c = b.sortByKey().map(x => x._2)
+    assert(c.collect() === Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+  }
+
+  test("sort with Java non serializable class - Java") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // default Java serializer cannot handle the non serializable class.
+    val thrown = intercept[SparkException] {
+      b.sortByKey().collect()
+    }
+
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getMessage.contains("NotSerializableException"))
+  }
 }
 
 object ShuffleSuite {
@@ -215,5 +251,9 @@ object ShuffleSuite {
     x + y
   }
 
-  class NonJavaSerializableClass(val value: Int)
+  class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] {
+    override def compareTo(o: NonJavaSerializableClass): Int = {
+      value - o.value
+    }
+  }
 }
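
For context, here is a minimal standalone driver sketch (not part of the patch) of the scenario the new Kryo test exercises: with spark.serializer set to KryoSerializer, sortByKey builds a RangePartitioner whose rangeBounds hold keys that are not java.io.Serializable, and the writeObject/readObject path added above writes those bounds through the user-specified serializer instead of plain Java serialization. The KryoOnlyKey class, the object/app names, and the local[2] master are illustrative assumptions, not identifiers from this change.

import org.apache.spark.{SparkConf, SparkContext}

// Illustrative key type: sortable via Comparable but not java.io.Serializable,
// mirroring NonJavaSerializableClass in the test suite above.
class KryoOnlyKey(val value: Int) extends Comparable[KryoOnlyKey] {
  override def compareTo(o: KryoOnlyKey): Int = value - o.value
}

object RangePartitionerKryoExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("range-partitioner-kryo-example")
      .setMaster("local[2]") // the test uses local-cluster[2,1,512] to force remote blocks
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)
    try {
      // sortByKey samples the RDD and builds a RangePartitioner whose rangeBounds
      // contain KryoOnlyKey instances; with this patch those bounds are serialized
      // with the configured Kryo serializer, so the job no longer fails with a
      // NotSerializableException as it would under default Java serialization.
      val sorted = sc.parallelize(1 to 10, 2)
        .map(x => (new KryoOnlyKey(x), x))
        .sortByKey()
        .map(_._2)
        .collect()
      println(sorted.mkString(", "))
    } finally {
      sc.stop()
    }
  }
}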