diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index feeb6c02caa7836114ecd7045ebe2b07733b71f9..39925db77f60c5d06b8a7b39613495b661f69735 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} @@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) + mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** * :: Experimental :: @@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** * :: Experimental :: @@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -566,7 +567,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. */ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) + /** * Pass each value in the key-value pair RDD through a map function without changing the keys; diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index f917cfd1419ec6c9c9e2b146a54356791a52cd11..d230678238ab92d08c77b3d38989b1cf97ee6bc9 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -30,6 +30,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -390,7 +391,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). @@ -399,13 +400,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { timeout: Long, confidence: Double ): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * (Experimental) Approximate version of countByValue(). */ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout).map(mapAsJavaMap) + rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap) /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 22810cb1c662d3eb517c584b921ca2535ae3270b..b52d0a5028e84b67c967130e0d055c9fb898a708 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -19,10 +19,20 @@ package org.apache.spark.api.java import com.google.common.base.Optional +import scala.collection.convert.Wrappers.MapWrapper + private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { case Some(value) => Optional.of(value) case None => Optional.absent() } + + // Workaround for SPARK-3926 / SI-8911 + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = + new SerializableMapWrapper(underlying) + + class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) + extends MapWrapper(underlying) with java.io.Serializable + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 6c67934bda5b83a8adfdade61321ebff82629524..0e5966ed8374037ef7818cefe7c6cbf73c997ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions import scala.math.BigDecimal +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -114,7 +115,7 @@ object Row { // they are actually accessed. case row: ScalaRow => new Row(row) case map: scala.collection.Map[_, _] => - JavaConversions.mapAsJavaMap( + mapAsSerializableJavaMap( map.map { case (key, value) => (toJavaValue(key), toJavaValue(value)) }