Skip to content
Snippets Groups Projects
Commit 1289e717 authored by Mark Hamstra's avatar Mark Hamstra
Browse files

refactored _With API and added foreachPartition

parent b57df1f5
No related branches found
No related tags found
No related merge requests found
......@@ -365,60 +365,59 @@ abstract class RDD[T: ClassManifest](
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
/**
* Maps f over this RDD where f takes an additional parameter of type A. This
* additional parameter is produced by a factory method T => A which is called
* on each invocation of f. This factory method is produced by the factoryBuilder,
* an instance of which is constructed in each partition from the partition index
* and a seed value of type B.
*/
def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
factoryBuilder: (Int, B) => (T => A),
factorySeed: B,
preservesPartitioning: Boolean = false)
* Maps f over this RDD where, f takes an additional parameter of type A. This
* additional parameter is produced by constructorOfA, which is called in each
* partition with the index of that partition.
*/
def mapWith[A: ClassManifest, U: ClassManifest](constructorOfA: Int => A, preservesPartitioning: Boolean = false)
(f:(A, T) => U): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val factory = factoryBuilder(index, factorySeed)
iter.map(t => f(factory(t), t))
val a = constructorOfA(index)
iter.map(t => f(a, t))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
* FlatMaps f over this RDD where f takes an additional parameter of type A. This
* additional parameter is produced by a factory method T => A which is called
* on each invocation of f. This factory method is produced by the factoryBuilder,
* an instance of which is constructed in each partition from the partition index
* and a seed value of type B.
/**
* FlatMaps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructorOfA, which is called in each
* partition with the index of that partition.
*/
def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
factoryBuilder: (Int, B) => (T => A),
factorySeed: B,
preservesPartitioning: Boolean = false)
def flatMapWith[A: ClassManifest, U: ClassManifest](constructorOfA: Int => A, preservesPartitioning: Boolean = false)
(f:(A, T) => Seq[U]): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val factory = factoryBuilder(index, factorySeed)
iter.flatMap(t => f(factory(t), t))
val a = constructorOfA(index)
iter.flatMap(t => f(a, t))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
* Applies f to each element of this RDD, where f takes an additional parameter of type A.
* This additional parameter is produced by constructorOfA, which is called in each
* partition with the index of that partition.
*/
def foreachWith[A: ClassManifest](constructorOfA: Int => A)
(f:(A, T) => Unit) {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructorOfA(index)
iter.map(t => {f(a, t); t})
}
(new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
}
/**
* Filters this RDD with p, where p takes an additional parameter of type A. This
* additional parameter is produced by a factory method T => A which is called
* on each invocation of p. This factory method is produced by the factoryBuilder,
* an instance of which is constructed in each partition from the partition index
* and a seed value of type B.
*/
def filterWith[A: ClassManifest, B: ClassManifest](
factoryBuilder: (Int, B) => (T => A),
factorySeed: B,
preservesPartitioning: Boolean = false)
* additional parameter is produced by constructorOfA, which is called in each
* partition with the index of that partition.
*/
def filterWith[A: ClassManifest](constructorOfA: Int => A)
(p:(A, T) => Boolean): RDD[T] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val factory = factoryBuilder(index, factorySeed)
iter.filter(t => p(factory(t), t))
val a = constructorOfA(index)
iter.filter(t => p(a, t))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
}
/**
......@@ -439,6 +438,14 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
* Return an array that contains all of the elements in this RDD.
*/
......
......@@ -180,21 +180,18 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("mapWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(index: Int, seed: Int) => {
val prng = new java.util.Random(index + seed)
(_ => prng.nextDouble)},
42)
{(random: Double, t: Int) => random * t}.
collect()
(index: Int) => new Random(index + 42))
{(prng: Random, t: Int) => prng.nextDouble * t}.collect()
val prn42_3 = {
val prng42 = new java.util.Random(42)
val prng42 = new Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
val prng43 = new java.util.Random(43)
val prng43 = new Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(2) === prn42_3)
......@@ -202,21 +199,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("flatMapWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(index: Int, seed: Int) => {
val prng = new java.util.Random(index + seed)
(_ => prng.nextDouble)},
42)
{(random: Double, t: Int) => Seq(random * t, random * t * 10)}.
(index: Int) => new Random(index + 42))
{(prng: Random, t: Int) => {
val random = prng.nextDouble()
Seq(random * t, random * t * 10)}}.
collect()
val prn42_3 = {
val prng42 = new java.util.Random(42)
val prng42 = new Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
val prng43 = new java.util.Random(43)
val prng43 = new Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(5) === prn42_3 * 10)
......@@ -228,11 +225,8 @@ class RDDSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(index: Int, seed: Int) => {
val prng = new Random(index + seed)
(_ => prng.nextInt(3))},
42)
{(random: Int, t: Int) => random == 0}.
(index: Int) => new Random(index + 42))
{(prng: Random, t: Int) => prng.nextInt(3) == 0}.
collect()
val checkSample = {
val prng42 = new Random(42)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment