Commit c5e2810d authored by Josh Rosen

Add persist(), splits(), glom(), and mapPartitions() to Java API.

parent 2a60c998
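For orientation, here is a minimal sketch (not part of the commit) showing how the four new methods are called from Java. It assumes a local-mode JavaSparkContext constructed the same way as in the test suite below; the method signatures mirror the ones exercised by the tests.

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.FlatMapFunction;
import spark.storage.StorageLevel;

public class NewJavaApiSketch {
  public static void main(String[] args) {
    // Local-mode context; this constructor form is assumed, not shown in the diff.
    JavaSparkContext sc = new JavaSparkContext("local", "NewJavaApiSketch");

    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);

    // persist(): keep this RDD on disk between actions.
    rdd = rdd.persist(StorageLevel.DISK_ONLY());

    // splits(): one Split handle per partition.
    System.out.println(rdd.splits().size());  // 2

    // mapPartitions(): fold each partition in a single pass.
    JavaRDD<Integer> sums = rdd.mapPartitions(
        new FlatMapFunction<Iterator<Integer>, Integer>() {
          @Override
          public Iterable<Integer> apply(Iterator<Integer> iter) {
            int sum = 0;
            while (iter.hasNext()) {
              sum += iter.next();
            }
            return Collections.singletonList(sum);
          }
        });
    System.out.println(sums.collect());  // [3, 7]

    // glom(): expose each partition as a List.
    List<List<Integer>> parts = rdd.glom().collect();
    System.out.println(parts);  // [[1, 2], [3, 4]]
  }
}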
@@ -5,6 +5,7 @@ import spark.SparkContext.doubleRDDToDoubleRDDFunctions
import spark.api.java.function.{Function => JFunction}
import spark.util.StatCounter
import spark.partial.{BoundedDouble, PartialResult}
import spark.storage.StorageLevel
import java.lang.Double
@@ -23,6 +24,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
// first() has to be overridden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@@ -5,13 +5,13 @@ import spark.api.java.function.{Function2 => JFunction2}
import spark.api.java.function.{Function => JFunction}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.storage.StorageLevel
import spark._
import java.util.{List => JList}
import java.util.Comparator
import scala.Tuple2
import scala.collection.Map
import scala.collection.JavaConversions._
import org.apache.hadoop.mapred.JobConf
@@ -33,6 +33,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
// Transformations (return a new RDD)
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
@@ -2,6 +2,7 @@ package spark.api.java
import spark._
import spark.api.java.function.{Function => JFunction}
import spark.storage.StorageLevel
class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
JavaRDDLike[T, JavaRDD[T]] {
@@ -12,6 +13,8 @@ JavaRDDLike[T, JavaRDD[T]] {
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
// Transformations (return a new RDD)
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
@@ -9,7 +9,7 @@ import spark.storage.StorageLevel
import java.util.{List => JList}
import scala.collection.JavaConversions._
-import java.lang
+import java.{util, lang}
import scala.Tuple2
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@@ -19,6 +19,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def rdd: RDD[T]
def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
def context: SparkContext = rdd.context
def id: Int = rdd.id
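Note that splits copies rdd.splits into a fresh java.util.ArrayList, so the returned list is a snapshot; the Split handles themselves stay usable and can be passed back to iterator(), as in this small sketch (assuming a JavaSparkContext named sc, mirroring the iterator() test below):

JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
// Read the first partition directly through its Split handle.
int first = rdd.iterator(rdd.splits().get(0)).next();  // 1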
@@ -56,9 +58,28 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
-new JavaPairRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
}
def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
}
def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
JavaPairRDD[K, V] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}
def glom(): JavaRDD[JList[T]] =
new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
other.classManifest)
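The three mapPartitions overloads are distinguished purely by the function argument's type, so Java overload resolution selects the matching return type (JavaRDD, JavaDoubleRDD, or JavaPairRDD). As a hedged sketch of the pair variant, reusing the JavaRDD<Integer> rdd from the sketch above and assuming PairFlatMapFunction.apply has the same Iterable-returning shape as the FlatMapFunction used in the tests below:

// Count the elements of each partition, emitting one pair per partition.
// Assumes scala.Tuple2 and java.util.* are imported as in the test suite.
JavaPairRDD<String, Integer> perPartitionCounts = rdd.mapPartitions(
    new PairFlatMapFunction<Iterator<Integer>, String, Integer>() {
      @Override
      public Iterable<Tuple2<String, Integer>> apply(Iterator<Integer> iter) {
        int n = 0;
        while (iter.hasNext()) {
          iter.next();
          n += 1;
        }
        return Collections.singletonList(new Tuple2<String, Integer>("count", n));
      }
    });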
@@ -21,6 +21,7 @@ import spark.api.java.JavaSparkContext;
import spark.api.java.function.*;
import spark.partial.BoundedDouble;
import spark.partial.PartialResult;
import spark.storage.StorageLevel;
import spark.util.StatCounter;
import java.io.File;
@@ -337,6 +338,55 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(11, pairs.count());
}
@Test
public void mapPartitions() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
JavaRDD<Integer> partitionSums = rdd.mapPartitions(
new FlatMapFunction<Iterator<Integer>, Integer>() {
@Override
public Iterable<Integer> apply(Iterator<Integer> iter) {
int sum = 0;
while (iter.hasNext()) {
sum += iter.next();
}
return Collections.singletonList(sum);
}
});
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
}
@Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
Assert.assertEquals(20, doubleRDD.sum(), 0.1);
List<Tuple2<Integer, String>> pairs = Arrays.asList(
new Tuple2<Integer, String>(1, "a"),
new Tuple2<Integer, String>(2, "aa"),
new Tuple2<Integer, String>(3, "aaa")
);
JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY());
Assert.assertEquals("a", pairRDD.first()._2());
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
rdd = rdd.persist(StorageLevel.DISK_ONLY());
Assert.assertEquals(1, rdd.first().intValue());
}
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0)).next().intValue());
}
@Test
public void glom() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
}
// File input / output tests are largely adapted from FileSuite:
@Test