From e23938c3beaefe7d373c1f82d7b475352bd9c9bb Mon Sep 17 00:00:00 2001
From: Josh Rosen <rosenville@gmail.com>
Date: Sun, 22 Jul 2012 15:10:01 -0700
Subject: [PATCH] Use mapValues() in JavaPairRDD.cogroupResultToJava().

---
 core/src/main/scala/spark/RDD.scala           |  8 -------
 .../scala/spark/api/java/JavaPairRDD.scala    | 23 ++++++++++---------
 .../scala/spark/api/java/JavaRDDLike.scala    |  5 ++--
 3 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index bf94773214..429e9c936f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -361,14 +361,6 @@ class MappedRDD[U: ClassManifest, T: ClassManifest](
   override def compute(split: Split) = prev.iterator(split).map(f)
 }
 
-class PartitioningPreservingMappedRDD[U: ClassManifest, T: ClassManifest](
-  prev: RDD[T],
-  f: T => U)
-  extends MappedRDD[U, T](prev, f) {
-
-  override val partitioner = prev.partitioner
-}
-
 class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
     f: T => TraversableOnce[U])
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index dbd6ce526c..3db7dd8b8f 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -253,17 +253,18 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
 object JavaPairRDD {
   def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
     vcm: ClassManifest[T]): RDD[(K, JList[T])] =
-    new PartitioningPreservingMappedRDD(rdd, (x: (K, Seq[T])) => (x._1, seqAsJavaList(x._2)))
-
-  def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])
-  : RDD[(K, (JList[V], JList[W]))] = new PartitioningPreservingMappedRDD(rdd,
-    (x: (K, (Seq[V], Seq[W]))) => (x._1, (seqAsJavaList(x._2._1), seqAsJavaList(x._2._2))))
-
-  def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))])
-  : RDD[(K, (JList[V], JList[W1], JList[W2]))] = new PartitioningPreservingMappedRDD(rdd,
-    (x: (K, (Seq[V], Seq[W1], Seq[W2]))) => (x._1, (seqAsJavaList(x._2._1),
-      seqAsJavaList(x._2._2),
-      seqAsJavaList(x._2._3))))
+    rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
+
+  def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
+    vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
+    Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+
+  def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
+    Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
+    JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
+    (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
+      seqAsJavaList(x._2),
+      seqAsJavaList(x._3)))
 
   def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd)
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 9f6674df56..5e0cb15042 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -1,6 +1,6 @@
 package spark.api.java
 
-import spark.{PartitioningPreservingMappedRDD, Split, RDD}
+import spark.{Split, RDD}
 import spark.api.java.JavaPairRDD._
 import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import spark.partial.{PartialResult, BoundedDouble}
@@ -48,8 +48,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
     import scala.collection.JavaConverters._
     def fn = (x: T) => f.apply(x).asScala
-    new JavaDoubleRDD(new PartitioningPreservingMappedRDD(rdd.flatMap(fn),
-      ((x: java.lang.Double) => x.doubleValue())))
+    new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
   }
 
   def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
-- 
GitLab