diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index feeb6c02caa7836114ecd7045ebe2b07733b71f9..39925db77f60c5d06b8a7b39613495b661f69735 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext.rddToPairRDDFunctions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
@@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    * before sending results to a reducer, similarly to a "combiner" in MapReduce.
    */
   def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
-    mapAsJavaMap(rdd.reduceByKeyLocally(func))
+    mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func))
 
   /** Count the number of elements for each key, and return the result to the master as a Map. */
-  def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+  def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey())
 
   /**
    * :: Experimental ::
@@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    */
   @Experimental
   def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
-    rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+    rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap)
 
   /**
    * :: Experimental ::
@@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   @Experimental
   def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
   : PartialResult[java.util.Map[K, BoundedDouble]] =
-    rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+    rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap)
 
   /**
    * Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -566,7 +567,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
    */
-  def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+  def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap())
 
   /**
    * Pass each value in the key-value pair RDD through a map function without changing the keys;
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index f917cfd1419ec6c9c9e2b146a54356791a52cd11..d230678238ab92d08c77b3d38989b1cf97ee6bc9 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -30,6 +30,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
@@ -390,7 +391,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * combine step happens locally on the master, equivalent to running a single reduce task.
    */
   def countByValue(): java.util.Map[T, java.lang.Long] =
-    mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+    mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
 
   /**
    * (Experimental) Approximate version of countByValue().
@@ -399,13 +400,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     timeout: Long,
     confidence: Double
     ): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+    rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap)
 
   /**
    * (Experimental) Approximate version of countByValue().
    */
   def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
-    rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+    rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap)
 
   /**
    * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index 22810cb1c662d3eb517c584b921ca2535ae3270b..b52d0a5028e84b67c967130e0d055c9fb898a708 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,10 +19,23 @@ package org.apache.spark.api.java
 
+import scala.collection.convert.Wrappers.MapWrapper
+
 import com.google.common.base.Optional
 
 private[spark] object JavaUtils {
   def optionToOptional[T](option: Option[T]): Optional[T] =
     option match {
       case Some(value) => Optional.of(value)
       case None => Optional.absent()
     }
+
+  // Workaround for SPARK-3926 / SI-8911
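+  // mapAsJavaMap returns a scala.collection.convert.Wrappers.MapWrapper, which does not
+  // implement java.io.Serializable, so maps returned to Java callers (e.g. from
+  // collectAsMap() or countByKey()) could not be serialized.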
+  def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+    new SerializableMapWrapper(underlying)
+
+  class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+    extends MapWrapper(underlying) with java.io.Serializable
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index 6c67934bda5b83a8adfdade61321ebff82629524..0e5966ed8374037ef7818cefe7c6cbf73c997ea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
 import scala.collection.JavaConversions
 import scala.math.BigDecimal
 
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
 import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
 
 /**
@@ -114,7 +115,8 @@ object Row {
     // they are actually accessed.
     case row: ScalaRow => new Row(row)
     case map: scala.collection.Map[_, _] =>
-      JavaConversions.mapAsJavaMap(
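+      // Use a serializable wrapper (SPARK-3926) so converted map values can be serialized.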
+      mapAsSerializableJavaMap(
         map.map {
           case (key, value) => (toJavaValue(key), toJavaValue(value))
         }