diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 181f7e8b030fe08eed5757a51e48088cb0f813f1..aaf006b6cbdbe2ca334db1a02ff6a97e3a5b2bdc 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -7,6 +7,7 @@ import java.util.Random
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
 
 import mesos._
 
@@ -27,7 +28,12 @@ abstract class RDD[T: ClassManifest](
   def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
   def aggregateSplit() = new SplitRDD(this)
   def cache() = new CachedRDD(this)
-  def sample(withReplacement: Boolean, frac: Double, seed: Int) = new SampledRDD(this, withReplacement, frac, seed)
+
+  def sample(withReplacement: Boolean, frac: Double, seed: Int) =
+    new SampledRDD(this, withReplacement, frac, seed)
+
+  def flatMap[U: ClassManifest](f: T => Traversable[U]) =
+    new FlatMappedRDD(this, sc.clean(f))
 
   def foreach(f: T => Unit) {
     val cleanF = sc.clean(f)
@@ -140,6 +146,16 @@ extends RDD[T](prev.sparkContext) {
   override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
 }
 
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+  prev: RDD[T], f: T => Traversable[U])
+extends RDD[U](prev.sparkContext) {
+  override def splits = prev.splits
+  override def preferredLocations(split: Split) = prev.preferredLocations(split)
+  override def iterator(split: Split) =
+    prev.iterator(split).toStream.flatMap(f).iterator
+  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
 class SplitRDD[T: ClassManifest](prev: RDD[T])
 extends RDD[Array[T]](prev.sparkContext) {
   override def splits = prev.splits
@@ -281,7 +297,7 @@ extends RDD[T](sc) {
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
 }
 
-@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split {}
+@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split
 
 @serializable
 class CartesianRDD[T: ClassManifest, U:ClassManifest](
@@ -310,3 +326,18 @@ extends RDD[Pair[T, U]](sc) {
     rdd2.taskStarted(currSplit.s2, slot)
   }
 }
+
+@serializable class PairRDDExtras[K, V](rdd: RDD[(K, V)]) {
+  def reduceByKey(func: (V, V) => V): Map[K, V] = {
+    def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
+      for ((k, v) <- m2) {
+        m1.get(k) match {
+          case None => m1(k) = v
+          case Some(w) => m1(k) = func(w, v)
+        }
+      }
+      return m1
+    }
+    rdd.map(pair => HashMap(pair)).reduce(mergeMaps)
+  }
+}
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index e00eda9aa24528e78e62578a1626d04f8c4f33aa..20f04f863966c90a1d94f72e3b8bf6389016d53d 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -85,9 +85,14 @@ object SparkContext {
     def add(t1: Double, t2: Double): Double = t1 + t2
     def zero(initialValue: Double) = 0.0
   }
+
   implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
     def add(t1: Int, t2: Int): Int = t1 + t2
     def zero(initialValue: Int) = 0
   }
+
+  // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
+  implicit def rddToPairRDDExtras[K, V](rdd: RDD[(K, V)]) =
+    new PairRDDExtras(rdd)
 }
diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala
index 52bcb89f003fc226c93c8c74d6e2d9bc36efa49f..27d73aefbd69420f0e6aca4ed33e7cb339676cf7 100644
--- a/src/scala/spark/Utils.scala
+++ b/src/scala/spark/Utils.scala
@@ -2,7 +2,9 @@ package spark
 
 import java.io._
 
-private object Utils {
+import scala.collection.mutable.ArrayBuffer
+
+object Utils {
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream
     val oos = new ObjectOutputStream(bos)
@@ -25,4 +27,27 @@ private object Utils {
     }
     return ois.readObject.asInstanceOf[T]
   }
+
+  def isAlpha(c: Char) = {
+    (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
+  }
+
+  def splitWords(s: String): Seq[String] = {
+    val buf = new ArrayBuffer[String]
+    var i = 0
+    while (i < s.length) {
+      var j = i
+      while (j < s.length && isAlpha(s.charAt(j))) {
+        j += 1
+      }
+      if (j > i) {
+        buf += s.substring(i, j)
+      }
+      i = j
+      while (i < s.length && !isAlpha(s.charAt(i))) {
+        i += 1
+      }
+    }
+    return buf
+  }
 }
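
Taken together, the pieces added in this patch are enough to express a simple word count: RDD.flatMap to tokenize lines via Utils.splitWords, and the rddToPairRDDExtras implicit conversion to make PairRDDExtras.reduceByKey available on RDDs of pairs. The sketch below is illustrative only and not part of the commit; it assumes a SparkContext named sc and a parallelize method on it for turning a local collection into an RDD.

    import spark.SparkContext._  // brings the rddToPairRDDExtras implicit into scope

    // Hypothetical driver program; assumes sc: SparkContext and sc.parallelize
    val lines = sc.parallelize(Seq("to be or not to be", "to do is to do"), 2)

    // flatMap splits each line into words; the implicit conversion then
    // makes reduceByKey available on the resulting RDD of (word, 1) pairs
    val counts = lines.flatMap(line => Utils.splitWords(line))
                      .map(word => (word, 1))
                      .reduceByKey(_ + _)

    for ((word, n) <- counts)
      println(word + ": " + n)

Note that this first version of reduceByKey returns a local Map on the driver rather than a distributed RDD: each pair is wrapped in a singleton HashMap and the maps are merged pairwise by rdd.reduce, so all results end up in the caller's memory.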