diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 3d1b1ca268c8bdf251a276f6b46feb478f56645d..07efba9e8d26efa2e9a737f28d90e70a665f51d6 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -88,6 +88,33 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) } + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { + combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { + foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) + } + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { + foldByKey(zeroValue, defaultPartitioner(self))(func) + } + /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index c1bd13c49a9e64a0dcbbad35fa2c1e9f02c32d81..49aaabf835648c6d8fa0a080dc86876d073dbb9e 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -160,6 +160,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue, partitioner)(func)) + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func)) + + /** + * Merge the values for each key using an associative function and a neutral "zero value" which may + * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for + * list concatenation, 0 for addition, or 1 for multiplication.). + */ + def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = + fromRDD(rdd.foldByKey(zeroValue)(func)) + /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 26e3ab72c0c2b233c31b9f2ecf5dd5ea292af678..d3dcd3bbebe7e472a69516e1101d44e2219753d3 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -196,6 +196,28 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(33, sum); } + @Test + public void foldByKey() { + List<Tuple2<Integer, Integer>> pairs = Arrays.asList( + new Tuple2<Integer, Integer>(2, 1), + new Tuple2<Integer, Integer>(2, 1), + new Tuple2<Integer, Integer>(1, 1), + new Tuple2<Integer, Integer>(3, 2), + new Tuple2<Integer, Integer>(3, 1) + ); + JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs); + JavaPairRDD<Integer, Integer> sums = rdd.foldByKey(0, + new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }); + Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); + Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); + Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + } + @Test public void reduceByKey() { List<Tuple2<Integer, Integer>> pairs = Arrays.asList(