From 72f6dbf7b0c8b271f5f9c762374422c69c8ab43d Mon Sep 17 00:00:00 2001
From: EugenCepoi <cepoi.eugen@gmail.com>
Date: Mon, 31 Aug 2015 13:24:35 -0500
Subject: [PATCH] [SPARK-8730] Fixes - Deser objects containing a primitive
 class attribute

Author: EugenCepoi <cepoi.eugen@gmail.com>

Closes #7122 from EugenCepoi/master.
---
 .../spark/serializer/JavaSerializer.scala     | 27 +++++++++++++++----
 .../serializer/JavaSerializerSuite.scala      | 18 +++++++++++++
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 4a5274b46b..b463a71d5b 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -62,17 +62,34 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa
   extends DeserializationStream {
 
   private val objIn = new ObjectInputStream(in) {
-    override def resolveClass(desc: ObjectStreamClass): Class[_] = {
-      // scalastyle:off classforname
-      Class.forName(desc.getName, false, loader)
-      // scalastyle:on classforname
-    }
+    override def resolveClass(desc: ObjectStreamClass): Class[_] =
+      try {
+        // scalastyle:off classforname
+        Class.forName(desc.getName, false, loader)
+        // scalastyle:on classforname
+      } catch {
+        case e: ClassNotFoundException =>
+          JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e)
+      }
   }
 
   def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
   def close() { objIn.close() }
 }
 
+private object JavaDeserializationStream {
+  val primitiveMappings = Map[String, Class[_]](
+    "boolean" -> classOf[Boolean],
+    "byte" -> classOf[Byte],
+    "char" -> classOf[Char],
+    "short" -> classOf[Short],
+    "int" -> classOf[Int],
+    "long" -> classOf[Long],
+    "float" -> classOf[Float],
+    "double" -> classOf[Double],
+    "void" -> classOf[Void]
+  )
+}
 
 private[spark] class JavaSerializerInstance(
     counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader)
diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
index 329a2b6dad..20f45670bc 100644
--- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -25,4 +25,22 @@ class JavaSerializerSuite extends SparkFunSuite {
     val instance = serializer.newInstance()
     instance.deserialize[JavaSerializer](instance.serialize(serializer))
   }
+
+  test("Deserialize object containing a primitive Class as attribute") {
+    val serializer = new JavaSerializer(new SparkConf())
+    val instance = serializer.newInstance()
+    instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass()))
+  }
+}
+
+private class ContainsPrimitiveClass extends Serializable {
+  val intClass = classOf[Int]
+  val longClass = classOf[Long]
+  val shortClass = classOf[Short]
+  val charClass = classOf[Char]
+  val doubleClass = classOf[Double]
+  val floatClass = classOf[Float]
+  val booleanClass = classOf[Boolean]
+  val byteClass = classOf[Byte]
+  val voidClass = classOf[Void]
 }
-- 
GitLab