Commit 66135a34 authored by jerryshao, committed by Reynold Xin

[SPARK-2104] Fix task serializing issues when sort with Java non serializable class

Details can be seen in [SPARK-2104](https://issues.apache.org/jira/browse/SPARK-2104). This work is based on Reynold's work and adds some unit tests to validate the issue.
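For context, the failure looked roughly like the sketch below (my reconstruction for illustration; `MyKey` and `Spark2104Repro` are made-up names, not from the JIRA). Even with Kryo configured, sorting by a key class that is not java.io.Serializable threw NotSerializableException, because the `RangePartitioner` built by `sortByKey` was shipped with the task through plain Java serialization, `rangeBounds` included.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // rddToOrderedRDDFunctions for sortByKey

// Hypothetical key type that is NOT java.io.Serializable.
class MyKey(val v: Int) extends Comparable[MyKey] {
  override def compareTo(o: MyKey): Int = v - o.v
}

object Spark2104Repro {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .setAppName("spark-2104-repro")
      .setMaster("local-cluster[2,1,512]")
    val sc = new SparkContext(conf)
    val sorted = sc.parallelize(1 to 10, 2)
      .map(x => (new MyKey(x), x))
      .sortByKey() // builds a RangePartitioner whose rangeBounds hold MyKey values
      .collect()   // before this fix: java.io.NotSerializableException, despite Kryo
    println(sorted.map(_._2).mkString(","))
    sc.stop()
  }
}
```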

@rxin, would you please take a look at this PR? Thanks a lot.

Author: jerryshao <saisai.shao@intel.com>

Closes #1245 from jerryshao/SPARK-2104 and squashes the following commits:

c8ee362 [jerryshao] Make field partitions transient
2b41917 [jerryshao] Minor changes
47d763c [jerryshao] Fix task serializing issue when sort with Java non serializable class
parent 7b71a0e0
core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark
 
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.CollectionsUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{CollectionsUtils, Utils}
 
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -96,15 +98,15 @@ class HashPartitioner(partitions: Int) extends Partitioner {
  * the value of `partitions`.
  */
 class RangePartitioner[K : Ordering : ClassTag, V](
-    partitions: Int,
+    @transient partitions: Int,
     @transient rdd: RDD[_ <: Product2[K,V]],
-    private val ascending: Boolean = true)
+    private var ascending: Boolean = true)
   extends Partitioner {
 
-  private val ordering = implicitly[Ordering[K]]
+  private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
-  private val rangeBounds: Array[K] = {
+  private var rangeBounds: Array[K] = {
     if (partitions == 1) {
       Array()
     } else {
@@ -127,7 +129,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
 
   def numPartitions = rangeBounds.length + 1
 
-  private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
 
   def getPartition(key: Any): Int = {
     val k = key.asInstanceOf[K]
@@ -173,4 +175,40 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     result = prime * result + ascending.hashCode
     result
   }
+
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => out.defaultWriteObject()
+      case _ =>
+        out.writeBoolean(ascending)
+        out.writeObject(ordering)
+        out.writeObject(binarySearch)
+
+        val ser = sfactory.newInstance()
+        Utils.serializeViaNestedStream(out, ser) { stream =>
+          stream.writeObject(scala.reflect.classTag[Array[K]])
+          stream.writeObject(rangeBounds)
+        }
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => in.defaultReadObject()
+      case _ =>
+        ascending = in.readBoolean()
+        ordering = in.readObject().asInstanceOf[Ordering[K]]
+        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
+
+        val ser = sfactory.newInstance()
+        Utils.deserializeViaNestedStream(in, ser) { ds =>
+          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
+          rangeBounds = ds.readObject[Array[K]]()
+        }
    }
  }
 }
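The fix uses the standard Java custom-serialization hooks. writeObject asks SparkEnv for the configured serializer: if it is the plain JavaSerializer, the default path is fine; otherwise the simple fields are written by hand and rangeBounds goes through the user-configured serializer via Utils.serializeViaNestedStream, with readObject mirroring the same steps. A stripped-down sketch of the same mechanism outside Spark (`Codec`, `Env`, and `BoundsHolder` are invented names for illustration):

```scala
import java.io._

// Invented stand-in for SparkEnv.get.serializer: a process-wide codec that the
// writing and reading JVMs agree on, like a cluster-wide spark.serializer.
trait Codec {
  def encode(v: AnyRef): Array[Byte]
  def decode(bytes: Array[Byte]): AnyRef
}

object Env {
  @volatile var codec: Codec = new Codec { // default: plain Java serialization
    def encode(v: AnyRef): Array[Byte] = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(v)
      oos.close()
      bos.toByteArray
    }
    def decode(bytes: Array[Byte]): AnyRef =
      new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject()
  }
}

// Mirrors RangePartitioner: `bounds` may hold values that are not
// java.io.Serializable, so it is transient and hand-routed through the codec.
class BoundsHolder(@transient private var bounds: AnyRef) extends Serializable {
  def get: AnyRef = bounds

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()          // non-transient fields take the default path
    val bytes = Env.codec.encode(bounds)
    out.writeInt(bytes.length)        // length-prefix the nested payload
    out.write(bytes)
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    val bytes = new Array[Byte](in.readInt())
    in.readFully(bytes)
    bounds = Env.codec.decode(bytes)  // restore via the same codec
  }
}
```

The explicit length prefix is one simple way to frame the nested payload so the outer ObjectOutputStream stays well-formed; inside Spark, Utils.serializeViaNestedStream and Utils.deserializeViaNestedStream take care of the analogous framing.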
core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -206,6 +206,42 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // substracted rdd return results as Tuple2
     results(0) should be ((3, 33))
   }
+
+  test("sort with Java non serializable class - Kryo") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // If the Kryo serializer is not used correctly, the shuffle would fail because the
+    // default Java serializer cannot handle the non serializable class.
+    val c = b.sortByKey().map(x => x._2)
+    assert(c.collect() === Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+  }
+
+  test("sort with Java non serializable class - Java") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // default Java serializer cannot handle the non serializable class.
+    val thrown = intercept[SparkException] {
+      b.sortByKey().collect()
+    }
+
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getMessage.contains("NotSerializableException"))
+  }
 }
 
 object ShuffleSuite {
@@ -215,5 +251,9 @@ object ShuffleSuite {
     x + y
   }
 
-  class NonJavaSerializableClass(val value: Int)
+  class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] {
+    override def compareTo(o: NonJavaSerializableClass): Int = {
+      value - o.value
+    }
+  }
 }
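A side note on the fixture change at the bottom (my gloss, not part of the commit): NonJavaSerializableClass implements Comparable because sortByKey requires an implicit Ordering[K], and the Scala standard library derives one for any Comparable type, as this small demonstration shows (`Key` and `OrderingDemo` are invented names):

```scala
// A key that only implements compareTo, like the test's NonJavaSerializableClass.
class Key(val value: Int) extends Comparable[Key] {
  override def compareTo(o: Key): Int = value - o.value
}

object OrderingDemo extends App {
  // Resolves via scala.math.LowPriorityOrderingImplicits.ordered, which turns
  // any type viewable as Comparable[A] into an Ordering[A].
  val ord = implicitly[Ordering[Key]]
  assert(ord.lt(new Key(1), new Key(2)))
  println("Comparable keys are orderable without a hand-written Ordering")
}
```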