From 327404d88308854c5421b304f3a342e87547611e Mon Sep 17 00:00:00 2001
From: Sean Owen <sowen@cloudera.com>
Date: Sat, 18 Oct 2014 12:33:20 -0700
Subject: [PATCH] SPARK-3926 [CORE] Result of JavaRDD.collectAsMap() is not
 Serializable

Make the java.util.Map returned by JavaPairRDD.collectAsMap() Serializable,
since Java Map implementations generally are. Do the same for the Maps
returned by the other Java API methods (reduceByKeyLocally, countByKey,
countByValue, their *Approx variants, and the Maps produced by the Java SQL
Row conversion) by wrapping them in a MapWrapper subclass that also
implements java.io.Serializable.
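
A minimal sketch of the reported failure, for context (jsc stands for a live
JavaSparkContext and is not part of this patch; this is illustrative code,
not the project's test suite):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    val m: java.util.Map[String, Int] =
      jsc.parallelizePairs(java.util.Arrays.asList(("a", 1))).collectAsMap()
    // Previously m was a plain scala.collection.convert.Wrappers.MapWrapper,
    // which (in the Scala versions Spark builds against) does not implement
    // java.io.Serializable, so this line threw NotSerializableException.
    // With the SerializableMapWrapper introduced below, the write succeeds.
    new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(m)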

Author: Sean Owen <sowen@cloudera.com>

Closes #2805 from srowen/SPARK-3926 and squashes the following commits:

ecb78ee [Sean Owen] Fix conflict between java.io.Serializable and use of Scala's Serializable
f4717f9 [Sean Owen] Oops, fix compile problem
ae1b36f [Sean Owen] Expand to cover Maps returned from other Java API methods as well
51c26c2 [Sean Owen] Make JavaPairRDD.collectAsMap result Serializable since Java Maps generally are
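
As a sketch of what the expanded coverage buys (again illustrative only,
reusing the hypothetical jsc from above rather than real test code), the
Maps returned by the other Java API methods can now be written out the
same way:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    val pairs = jsc.parallelizePairs(java.util.Arrays.asList(("a", 1), ("b", 2)))
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    // countByKey() and countByValue() are backed by SerializableMapWrapper now.
    out.writeObject(pairs.countByKey())
    out.writeObject(pairs.countByValue())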
---
 .../org/apache/spark/api/java/JavaPairRDD.scala      | 12 +++++++-----
 .../org/apache/spark/api/java/JavaRDDLike.scala      |  7 ++++---
 .../scala/org/apache/spark/api/java/JavaUtils.scala  | 10 ++++++++++
 .../scala/org/apache/spark/sql/api/java/Row.scala    |  3 ++-
 4 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index feeb6c02ca..39925db77f 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext.rddToPairRDDFunctions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
@@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    * before sending results to a reducer, similarly to a "combiner" in MapReduce.
    */
   def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
-    mapAsJavaMap(rdd.reduceByKeyLocally(func))
+    mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func))
 
   /** Count the number of elements for each key, and return the result to the master as a Map. */
-  def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+  def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey())
 
   /**
    * :: Experimental ::
@@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    */
   @Experimental
   def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
-    rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+    rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap)
 
   /**
    * :: Experimental ::
@@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   @Experimental
   def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
   : PartialResult[java.util.Map[K, BoundedDouble]] =
-    rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+    rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap)
 
   /**
    * Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -566,7 +567,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
    */
-  def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+  def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap())
+
 
   /**
    * Pass each value in the key-value pair RDD through a map function without changing the keys;
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index f917cfd141..d230678238 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -30,6 +30,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
@@ -390,7 +391,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * combine step happens locally on the master, equivalent to running a single reduce task.
    */
   def countByValue(): java.util.Map[T, java.lang.Long] =
-    mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+    mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
 
   /**
    * (Experimental) Approximate version of countByValue().
@@ -399,13 +400,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     timeout: Long,
     confidence: Double
     ): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+    rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap)
 
   /**
    * (Experimental) Approximate version of countByValue().
    */
   def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+    rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap)
 
   /**
    * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index 22810cb1c6..b52d0a5028 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,10 +19,20 @@ package org.apache.spark.api.java
 
 import com.google.common.base.Optional
 
+import scala.collection.convert.Wrappers.MapWrapper
+
 private[spark] object JavaUtils {
   def optionToOptional[T](option: Option[T]): Optional[T] =
     option match {
       case Some(value) => Optional.of(value)
       case None => Optional.absent()
     }
+
+  // Workaround for SPARK-3926 / SI-8911
+  def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+    new SerializableMapWrapper(underlying)
+
+  class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+    extends MapWrapper(underlying) with java.io.Serializable
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index 6c67934bda..0e5966ed83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
 import scala.collection.JavaConversions
 import scala.math.BigDecimal
 
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
 import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
 
 /**
@@ -114,7 +115,7 @@ object Row {
     // they are actually accessed.
     case row: ScalaRow => new Row(row)
     case map: scala.collection.Map[_, _] =>
-      JavaConversions.mapAsJavaMap(
+      mapAsSerializableJavaMap(
         map.map {
           case (key, value) => (toJavaValue(key), toJavaValue(value))
         }
-- 
GitLab