diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index e7f75481939a8aa2cec36543f3b6e54d11b5d2d5..ec99648a8488a6440fa2fbe72bf3ec7fcc17be8b 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark
 
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.CollectionsUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{CollectionsUtils, Utils}
 
 /**
  * An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -96,15 +98,15 @@ class HashPartitioner(partitions: Int) extends Partitioner {
  * the value of `partitions`.
  */
 class RangePartitioner[K : Ordering : ClassTag, V](
-    partitions: Int,
+    @transient partitions: Int,
     @transient rdd: RDD[_ <: Product2[K,V]],
-    private val ascending: Boolean = true)
+    private var ascending: Boolean = true)
   extends Partitioner {
 
-  private val ordering = implicitly[Ordering[K]]
+  private var ordering = implicitly[Ordering[K]]
 
   // An array of upper bounds for the first (partitions - 1) partitions
-  private val rangeBounds: Array[K] = {
+  private var rangeBounds: Array[K] = {
     if (partitions == 1) {
       Array()
     } else {
@@ -127,7 +129,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
 
   def numPartitions = rangeBounds.length + 1
 
-  private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
 
   def getPartition(key: Any): Int = {
     val k = key.asInstanceOf[K]
@@ -173,4 +175,44 @@ class RangePartitioner[K : Ordering : ClassTag, V](
     result = prime * result + ascending.hashCode
     result
   }
+
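+  // When a serializer other than JavaSerializer (e.g. Kryo) is configured, write the fields by
+  // hand and ship rangeBounds through that serializer so non-Java-serializable keys still work.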
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => out.defaultWriteObject()
+      case _ =>
+        out.writeBoolean(ascending)
+        out.writeObject(ordering)
+        out.writeObject(binarySearch)
+
+        val ser = sfactory.newInstance()
+        Utils.serializeViaNestedStream(out, ser) { stream =>
+          stream.writeObject(scala.reflect.classTag[Array[K]])
+          stream.writeObject(rangeBounds)
+        }
+    }
+  }
+
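+  // Mirror of writeObject: read the fields back, deserializing rangeBounds through the
+  // configured serializer when it is not the plain JavaSerializer.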
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream) {
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => in.defaultReadObject()
+      case _ =>
+        ascending = in.readBoolean()
+        ordering = in.readObject().asInstanceOf[Ordering[K]]
+        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
+
+        val ser = sfactory.newInstance()
+        Utils.deserializeViaNestedStream(in, ser) { ds =>
+          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
+          rangeBounds = ds.readObject[Array[K]]()
+        }
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index b40fee7e9ab236889cee630d61d6e9e2b34355ab..c4f2f7e34f4d5905890704debf6303d868b75e43 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -206,6 +206,42 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // subtracted rdd returns results as Tuple2
     results(0) should be ((3, 33))
   }
+
+  test("sort with Java non serializable class - Kryo") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // If the Kryo serializer is not used correctly, the shuffle would fail because the
+    // default Java serializer cannot handle the non-serializable key class.
+    val c = b.sortByKey().map(x => x._2)
+    assert(c.collect() === Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+  }
+
+  test("sort with Java non serializable class - Java") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    val conf = new SparkConf()
+      .setAppName("test")
+      .setMaster("local-cluster[2,1,512]")
+    sc = new SparkContext(conf)
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (new NonJavaSerializableClass(x), x)
+    }
+    // The default Java serializer cannot handle the non-serializable class, so the sort fails.
+    val thrown = intercept[SparkException] {
+      b.sortByKey().collect()
+    }
+
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getMessage.contains("NotSerializableException"))
+  }
 }
 
 object ShuffleSuite {
@@ -215,5 +251,11 @@ object ShuffleSuite {
     x + y
   }
 
-  class NonJavaSerializableClass(val value: Int)
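+  // Comparable supplies the implicit Ordering that sortByKey() needs for this key type, while
+  // the class itself stays non-Java-serializable.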
+  class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] {
+    override def compareTo(o: NonJavaSerializableClass): Int = {
+      value - o.value
+    }
+  }
 }