Commit 66135a34 authored by jerryshao, committed by Reynold Xin

[SPARK-2104] Fix task serializing issues when sort with Java non serializable class

Details can be seen in [SPARK-2104](https://issues.apache.org/jira/browse/SPARK-2104). This work is based on Reynold's work and adds some unit tests to validate the issue.

@rxin, would you please take a look at this PR? Thanks a lot.

Author: jerryshao <saisai.shao@intel.com>

Closes #1245 from jerryshao/SPARK-2104 and squashes the following commits:

c8ee362 [jerryshao] Make field partitions transient
2b41917 [jerryshao] Minor changes
47d763c [jerryshao] Fix task serializing issue when sort with Java non serializable class
parent 7b71a0e0
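The heart of the fix is the pair of private writeObject/readObject hooks added to RangePartitioner below. Java serialization invokes these methods reflectively, which lets the partitioner route its rangeBounds array (whose element type K may not implement java.io.Serializable) through whatever serializer is configured in SparkEnv. For readers unfamiliar with the hook mechanism itself, here is a minimal, self-contained sketch using plain JDK streams; the Wrapper class and its payload field are illustrative only, not part of the patch:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException,
  ObjectInputStream, ObjectOutputStream}

// A Serializable class that customizes how one field crosses the wire via
// Java's private writeObject/readObject hooks, like RangePartitioner does.
class Wrapper(@transient private var payload: Array[Int]) extends Serializable {

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()       // write the non-transient fields as usual
    out.writeInt(payload.length)   // then hand-encode the transient payload
    payload.foreach(out.writeInt)
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    payload = Array.fill(in.readInt())(in.readInt())  // rebuild the field
  }

  def values: Array[Int] = payload
}

object Wrapper {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(new Wrapper(Array(1, 2, 3)))
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val copy = ois.readObject().asInstanceOf[Wrapper]
    assert(copy.values.sameElements(Array(1, 2, 3)))
  }
}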
core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark
 
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.CollectionsUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{CollectionsUtils, Utils}
 
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -96,15 +98,15 @@ class HashPartitioner(partitions: Int) extends Partitioner {
  * the value of `partitions`.
  */
 class RangePartitioner[K : Ordering : ClassTag, V](
-    partitions: Int,
+    @transient partitions: Int,
     @transient rdd: RDD[_ <: Product2[K,V]],
-    private val ascending: Boolean = true)
+    private var ascending: Boolean = true)
   extends Partitioner {
 
-  private val ordering = implicitly[Ordering[K]]
+  private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
-  private val rangeBounds: Array[K] = {
+  private var rangeBounds: Array[K] = {
     if (partitions == 1) {
       Array()
     } else {
@@ -127,7 +129,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
 
   def numPartitions = rangeBounds.length + 1
 
-  private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
 
   def getPartition(key: Any): Int = {
     val k = key.asInstanceOf[K]
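Aside: getPartition, partly visible in the hunk above, maps a key to a bucket by searching the sorted rangeBounds array, which is exactly the state the new hooks must ship to executors. A rough sketch of that lookup, specialized to Int keys for brevity (illustrative only; the real code dispatches through the reflection-generated binarySearch shown above):

// Sketch: map a key to a partition by binary-searching the sorted upper
// bounds; keys greater than every bound land in the last partition.
// Specialized to Int keys -- this is not Spark's actual code.
object RangeLookupSketch {
  def rangeLookup(bounds: Array[Int], key: Int): Int = {
    var lo = 0
    var hi = bounds.length
    while (lo < hi) {              // find the first bound >= key
      val mid = (lo + hi) >>> 1
      if (bounds(mid) < key) lo = mid + 1 else hi = mid
    }
    lo                             // equals bounds.length when key exceeds all bounds
  }

  def main(args: Array[String]): Unit = {
    assert(rangeLookup(Array(10, 20), 5) == 0)   // at or below the first bound
    assert(rangeLookup(Array(10, 20), 15) == 1)  // between the two bounds
    assert(rangeLookup(Array(10, 20), 99) == 2)  // past every bound
  }
}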
@@ -173,4 +175,40 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     result = prime * result + ascending.hashCode
     result
   }
+
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => out.defaultWriteObject()
+      case _ =>
+        out.writeBoolean(ascending)
+        out.writeObject(ordering)
+        out.writeObject(binarySearch)
+
+        val ser = sfactory.newInstance()
+        Utils.serializeViaNestedStream(out, ser) { stream =>
+          stream.writeObject(scala.reflect.classTag[Array[K]])
+          stream.writeObject(rangeBounds)
+        }
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => in.defaultReadObject()
+      case _ =>
+        ascending = in.readBoolean()
+        ordering = in.readObject().asInstanceOf[Ordering[K]]
+        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
+
+        val ser = sfactory.newInstance()
+        Utils.deserializeViaNestedStream(in, ser) { ds =>
+          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
+          rangeBounds = ds.readObject[Array[K]]()
+        }
+    }
+  }
 }
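Two details above are easy to miss. First, ascending, ordering, rangeBounds, and binarySearch all change from val to var: Java deserialization creates the instance without running its constructor, so readObject must be able to reassign these fields. (partitions also becomes @transient since it is only consulted while computing rangeBounds on the driver.) Second, Utils.serializeViaNestedStream embeds the configured serializer's stream inside the outer ObjectOutputStream. One simple way to picture that nesting is to length-prefix the inner serializer's bytes; a rough, self-contained sketch, where the MiniSerializer trait is a hypothetical stand-in for Spark's SerializerInstance (not Spark's actual API or framing):

import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream}

// Hypothetical pluggable serializer standing in for Spark's SerializerInstance.
trait MiniSerializer {
  def serialize[T](t: T): Array[Byte]
  def deserialize[T](bytes: Array[Byte]): T
}

object NestedStreamSketch {
  // Nest a payload encoded by `ser` inside an outer byte stream by
  // length-prefixing the inner bytes, so the reader knows where they end.
  def writeNested[T](out: OutputStream, ser: MiniSerializer, payload: T): Unit = {
    val inner = ser.serialize(payload)
    val dos = new DataOutputStream(out)
    dos.writeInt(inner.length)   // length prefix delimits the inner encoding
    dos.write(inner)
    dos.flush()
  }

  def readNested[T](in: InputStream, ser: MiniSerializer): T = {
    val dis = new DataInputStream(in)
    val inner = new Array[Byte](dis.readInt())
    dis.readFully(inner)         // pull exactly the prefixed number of bytes
    ser.deserialize[T](inner)
  }
}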
core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -206,6 +206,42 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // substracted rdd return results as Tuple2
     results(0) should be ((3, 33))
   }
+
+  test("sort with Java non serializable class - Kryo") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // If the Kryo serializer is not used correctly, the shuffle would fail because the
+    // default Java serializer cannot handle the non serializable class.
+    val c = b.sortByKey().map(x => x._2)
+    assert(c.collect() === Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+  }
+
+  test("sort with Java non serializable class - Java") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // default Java serializer cannot handle the non serializable class.
+    val thrown = intercept[SparkException] {
+      b.sortByKey().collect()
+    }
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getMessage.contains("NotSerializableException"))
+  }
 }
@@ -215,5 +251,9 @@ object ShuffleSuite {
     x + y
   }
 
-  class NonJavaSerializableClass(val value: Int)
+  class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] {
+    override def compareTo(o: NonJavaSerializableClass): Int = {
+      value - o.value
+    }
+  }
 }
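The two tests exercise both sides of the fix: with Kryo configured the sort now succeeds end to end, while the plain Java serializer still fails with a NotSerializableException, showing that the partitioner delegates to the configured serializer rather than silently falling back to Java serialization. For heavily shuffled key classes one can additionally register them with Kryo via Spark's spark.kryo.registrator hook; a sketch (the registrator class name is made up for illustration, and it assumes ShuffleSuite.NonJavaSerializableClass is visible on the classpath, which in practice it is only for test code):

import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Registering the class lets Kryo emit a compact numeric ID instead of the
// full class name with every record. The class names here are illustrative.
class NonSerializableKeyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[ShuffleSuite.NonJavaSerializableClass])
  }
}

object KryoRegistrationExample {
  def conf: SparkConf = new SparkConf()
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .set("spark.kryo.registrator", classOf[NonSerializableKeyRegistrator].getName)
}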