diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
index 9a90d0af79aa059fe1846dce157ea7fc39d83a3e..7c0b17c45ef6a59216f51e3c6bdb218eb0779a22 100644
--- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
@@ -5,6 +5,7 @@ import spark.SparkContext.doubleRDDToDoubleRDDFunctions
 import spark.api.java.function.{Function => JFunction}
 import spark.util.StatCounter
 import spark.partial.{BoundedDouble, PartialResult}
+import spark.storage.StorageLevel
 
 import java.lang.Double
 
@@ -23,6 +24,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
 
   def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
 
+  def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+
   // first() has to be overriden here in order for its return type to be Double instead of Object.
   override def first(): Double = srdd.first()
 
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 99d1b1e2088ecbae1b074214b034854ef555578c..c28a13b0619be13593c4258a2a57909dcd24fe68 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -5,13 +5,13 @@ import spark.api.java.function.{Function2 => JFunction2}
 import spark.api.java.function.{Function => JFunction}
 import spark.partial.BoundedDouble
 import spark.partial.PartialResult
+import spark.storage.StorageLevel
 import spark._
 
 import java.util.{List => JList}
 import java.util.Comparator
 
 import scala.Tuple2
-import scala.collection.Map
 import scala.collection.JavaConversions._
 
 import org.apache.hadoop.mapred.JobConf
@@ -33,6 +33,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
 
   def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
 
+  def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
+    new JavaPairRDD[K, V](rdd.persist(newLevel))
+
   // Transformations (return a new RDD)
 
   def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index 598d4cf15b9a9ccb088abec09d0bac2689b227c4..541aa1e60be94b18528ed9a331b0c18ebe150eaa 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -2,6 +2,7 @@ package spark.api.java
 
 import spark._
 import spark.api.java.function.{Function => JFunction}
+import spark.storage.StorageLevel
 
 class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
 JavaRDDLike[T, JavaRDD[T]] {
@@ -12,6 +13,8 @@ JavaRDDLike[T, JavaRDD[T]] {
 
   def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
 
+  def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+
   // Transformations (return a new RDD)
 
   def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
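The three wrappers above gain the same persist(StorageLevel) hook that the Scala RDD API already exposes. A rough caller-side sketch follows (illustration only, not part of this patch; it assumes a JavaSparkContext named sc and the imports already used in JavaAPISuite):

    // Illustrative only: each Java wrapper can now be pinned at an explicit storage
    // level rather than the default used by cache(). DISK_ONLY is the level the new
    // test exercises; other levels are available on spark.storage.StorageLevel.
    JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
    nums = nums.persist(StorageLevel.DISK_ONLY());

    JavaDoubleRDD doubles = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0));
    doubles = doubles.persist(StorageLevel.DISK_ONLY());

    JavaPairRDD<Integer, String> pairs = sc.parallelizePairs(Arrays.asList(
        new Tuple2<Integer, String>(1, "a"),
        new Tuple2<Integer, String>(2, "b")));
    pairs = pairs.persist(StorageLevel.DISK_ONLY());

Because persist returns the wrapper type itself (JavaDoubleRDD, JavaPairRDD[K, V], JavaRDD[T]), the result can be reassigned and chained just like cache().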
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 1c6948eb7fbbf5654b9a862585d09992df8a552c..785dd96394ccfb8f64a4f4deeaeebb04f17d3de4 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -9,7 +9,7 @@ import spark.storage.StorageLevel
 import java.util.{List => JList}
 
 import scala.collection.JavaConversions._
-import java.lang
+import java.{util, lang}
 import scala.Tuple2
 
 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@@ -19,6 +19,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
 
   def rdd: RDD[T]
 
+  def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
+
   def context: SparkContext = rdd.context
 
   def id: Int = rdd.id
@@ -56,9 +58,28 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     import scala.collection.JavaConverters._
     def fn = (x: T) => f.apply(x).asScala
     def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
-    new JavaPairRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+    JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+  }
+
+  def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
+    def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+    JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
+  }
+
+  def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
+    def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+    new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
+  }
+
+  def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
+  JavaPairRDD[K, V] = {
+    def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+    JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
   }
 
+  def glom(): JavaRDD[JList[T]] =
+    new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+
   def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
     JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
       other.classManifest)
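Since JavaRDDLike is the shared surface mixed into all three wrappers, here is a rough caller-side sketch of the new methods (illustration only, not part of this patch; it assumes a JavaSparkContext named sc and the imports used by JavaAPISuite, such as java.util.Arrays, java.util.Iterator, java.util.Collections and java.util.List):

    // Illustrative only: sum each partition with the new mapPartitions overload,
    // then inspect the raw partition contents via glom() and splits().
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);

    JavaRDD<Integer> partitionSums = rdd.mapPartitions(
        new FlatMapFunction<Iterator<Integer>, Integer>() {
          @Override
          public Iterable<Integer> apply(Iterator<Integer> iter) {
            int sum = 0;
            while (iter.hasNext()) {
              sum += iter.next();
            }
            return Collections.singletonList(sum);
          }
        });
    // With the two partitions [1, 2] and [3, 4], collect() is expected to yield [3, 7].

    List<List<Integer>> partitionContents = rdd.glom().collect();  // expected [[1, 2], [3, 4]]
    int numSplits = rdd.splits().size();                           // 2 in this example

The DoubleFlatMapFunction and PairFlatMapFunction overloads follow the same shape, producing a JavaDoubleRDD or JavaPairRDD instead of a JavaRDD.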
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index f6c0e539e6e4b4bec2938e9095fa9d046fb2b669..436a8ab0c71ac9bb06c8bbc35dc069da0b6a054a 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -21,6 +21,7 @@ import spark.api.java.JavaSparkContext;
 import spark.api.java.function.*;
 import spark.partial.BoundedDouble;
 import spark.partial.PartialResult;
+import spark.storage.StorageLevel;
 import spark.util.StatCounter;
 
 import java.io.File;
@@ -337,6 +338,55 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(11, pairs.count());
   }
 
+  @Test
+  public void mapPartitions() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    JavaRDD<Integer> partitionSums = rdd.mapPartitions(
+      new FlatMapFunction<Iterator<Integer>, Integer>() {
+        @Override
+        public Iterable<Integer> apply(Iterator<Integer> iter) {
+          int sum = 0;
+          while (iter.hasNext()) {
+            sum += iter.next();
+          }
+          return Collections.singletonList(sum);
+        }
+      });
+    Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+  }
+
+  @Test
+  public void persist() {
+    JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+    doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
+    Assert.assertEquals(20, doubleRDD.sum(), 0.1);
+
+    List<Tuple2<Integer, String>> pairs = Arrays.asList(
+      new Tuple2<Integer, String>(1, "a"),
+      new Tuple2<Integer, String>(2, "aa"),
+      new Tuple2<Integer, String>(3, "aaa")
+    );
+    JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+    pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY());
+    Assert.assertEquals("a", pairRDD.first()._2());
+
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+    rdd = rdd.persist(StorageLevel.DISK_ONLY());
+    Assert.assertEquals(1, rdd.first().intValue());
+  }
+
+  @Test
+  public void iterator() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
+    Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0)).next().intValue());
+  }
+
+  @Test
+  public void glom() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+    Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
+  }
+
   // File input / output tests are largely adapted from FileSuite:
 
   @Test