Commit 9f20b6b4 authored by Matei Zaharia

Added reduceByKey operation for RDDs containing pairs

parent 34eccedb
@@ -7,6 +7,7 @@ import java.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
import mesos._
@@ -27,7 +28,12 @@ abstract class RDD[T: ClassManifest](
def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
def aggregateSplit() = new SplitRDD(this)
def cache() = new CachedRDD(this)
def sample(withReplacement: Boolean, frac: Double, seed: Int) = new SampledRDD(this, withReplacement, frac, seed)
def sample(withReplacement: Boolean, frac: Double, seed: Int) =
new SampledRDD(this, withReplacement, frac, seed)
def flatMap[U: ClassManifest](f: T => Traversable[U]) =
new FlatMappedRDD(this, sc.clean(f))
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
@@ -140,6 +146,16 @@ extends RDD[T](prev.sparkContext) {
override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T], f: T => Traversable[U])
extends RDD[U](prev.sparkContext) {
override def splits = prev.splits
override def preferredLocations(split: Split) = prev.preferredLocations(split)
override def iterator(split: Split) =
prev.iterator(split).toStream.flatMap(f).iterator
override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
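For context, a minimal usage sketch of the new flatMap operation (not part of this commit). It assumes a SparkContext named sc whose parallelize method builds an RDD from a local Seq; the variable names are illustrative only.

// Sketch only: split each line into words. Each input element may produce
// zero or more output elements, which flatMap concatenates into one RDD.
val lines = sc.parallelize(Seq("hello world", "hello spark"))
val words = lines.flatMap(line => line.split(" ").toSeq)
// words now contains "hello", "world", "hello", "spark"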
class SplitRDD[T: ClassManifest](prev: RDD[T])
extends RDD[Array[T]](prev.sparkContext) {
override def splits = prev.splits
@@ -281,7 +297,7 @@ extends RDD[T](sc) {
s.asInstanceOf[UnionSplit[T]].preferredLocations()
}
@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split {}
@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split
@serializable
class CartesianRDD[T: ClassManifest, U:ClassManifest](
@@ -310,3 +326,18 @@ extends RDD[Pair[T, U]](sc) {
rdd2.taskStarted(currSplit.s2, slot)
}
}
@serializable class PairRDDExtras[K, V](rdd: RDD[(K, V)]) {
def reduceByKey(func: (V, V) => V): Map[K, V] = {
def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
for ((k, v) <- m2) {
m1.get(k) match {
case None => m1(k) = v
case Some(w) => m1(k) = func(w, v)
}
}
return m1
}
rdd.map(pair => HashMap(pair)).reduce(mergeMaps)
}
}
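A minimal usage sketch of the new reduceByKey (not part of this commit), assuming a SparkContext named sc and the rddToPairRDDExtras implicit conversion below in scope. Note that this version returns a local Map rather than an RDD: every pair is wrapped in a one-entry HashMap and the maps are merged pairwise with reduce, so the result ends up on the driver.

// Sketch only: word-count style aggregation on a pair RDD.
val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 2)))
val counts = pairs.reduceByKey(_ + _)
// counts is a local Map: Map("a" -> 3, "b" -> 1)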
@@ -85,9 +85,14 @@ object SparkContext {
def add(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
}
implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
def add(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int) = 0
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
implicit def rddToPairRDDExtras[K, V](rdd: RDD[(K, V)]) =
new PairRDDExtras(rdd)
}
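The rddToPairRDDExtras implicit works through Scala's standard enrichment pattern: because the conversion is implicit and in scope, the compiler silently wraps an RDD[(K, V)] in a PairRDDExtras whenever reduceByKey is called on it. Below is a self-contained sketch of the same pattern, using a plain List in place of an RDD; all names here are illustrative and not part of the commit.

import scala.collection.mutable.HashMap

// Wrapper that adds reduceByKey to any List of key/value pairs.
class PairListExtras[K, V](list: List[(K, V)]) {
  def reduceByKey(func: (V, V) => V): HashMap[K, V] = {
    val merged = new HashMap[K, V]
    for ((k, v) <- list)
      merged(k) = merged.get(k).map(w => func(w, v)).getOrElse(v)
    merged
  }
}

object Implicits {
  // The implicit keyword lets the compiler apply this conversion
  // automatically when reduceByKey is called on a List[(K, V)].
  implicit def listToPairListExtras[K, V](list: List[(K, V)]) =
    new PairListExtras(list)
}

object Demo {
  import Implicits._
  def main(args: Array[String]) {
    val counts = List(("a", 1), ("b", 1), ("a", 2)).reduceByKey(_ + _)
    println(counts) // Map(a -> 3, b -> 1)
  }
}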