diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 803c0638653c0aba4677e2f68e28883a6f72b053..20865d7d28ea91d557948d8cbe195a3f6b155922 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -82,11 +82,10 @@ abstract class RDD[T: ClassManifest](
     try { map(x => 1L).reduce(_+_) }
     catch { case e: UnsupportedOperationException => 0L }
 
-  def union(other: RDD[T]) = new UnionRDD(sc, this, other)
+  def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other))
 
   def cartesian[U: ClassManifest](other: RDD[U]) = new CartesianRDD(sc, this, other)
 
   def ++(other: RDD[T]) = this.union(other)
-
 }
 @serializable
@@ -268,36 +267,27 @@ private object CachedRDD {
 }
 
 @serializable
-abstract class UnionSplit[T: ClassManifest] extends Split {
-  def iterator(): Iterator[T]
-  def preferredLocations(): Seq[String]
-  def getId(): String
-}
-
-@serializable
-class UnionSplitImpl[T: ClassManifest](
-  rdd: RDD[T], split: Split)
-extends UnionSplit[T] {
-  override def iterator() = rdd.iterator(split)
-  override def preferredLocations() = rdd.preferredLocations(split)
-  override def getId() =
-    "UnionSplitImpl(" + split.getId() + ")"
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+  def iterator() = rdd.iterator(split)
+  def preferredLocations() = rdd.preferredLocations(split)
+  override def getId() = "UnionSplit(" + split.getId() + ")"
 }
 
 @serializable
-class UnionRDD[T: ClassManifest](
-  sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T])
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
 extends RDD[T](sc) {
-
-  @transient val splits_ : Array[UnionSplit[T]] = {
-    val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
-    val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
-    (a1 ++ a2).toArray
+  @transient val splits_ : Array[Split] = {
+    val splits: Seq[Split] =
+      for (rdd <- rdds; split <- rdd.splits)
+        yield new UnionSplit(rdd, split)
+    splits.toArray
   }
 
-  override def splits = splits_.asInstanceOf[Array[Split]]
+  override def splits = splits_
 
-  override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+  override def iterator(s: Split): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator()
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 69c3332bb02d360bfaac7d66b5ed64601414b17f..953eac9eba109729a21ef071ffb1cec2353deb9d 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -33,13 +33,17 @@ extends Logging {
 
   // Methods for creating RDDs
 
-  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
+  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
     new ParallelArray[T](this, seq, numSlices)
 
-  def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
+  def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
     parallelize(seq, scheduler.numCores)
 
-  def textFile(path: String) = new HdfsTextFile(this, path)
+  def textFile(path: String): RDD[String] =
+    new HdfsTextFile(this, path)
+
+  def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+    new UnionRDD(this, rdds)
 
   // Methods for creating shared variables
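
Usage sketch (not part of the patch): a minimal example of the API after this change, assuming an already-constructed SparkContext bound to `sc`; the RDD names `a`, `b`, and `c` are hypothetical.

    // Hypothetical usage of the reworked union API:
    val a = sc.parallelize(Seq(1, 2, 3))
    val b = sc.parallelize(Seq(4, 5, 6))
    val c = sc.parallelize(Seq(7, 8, 9))

    // RDD.union is still binary; it now wraps both operands in an Array and
    // delegates to the Seq-based UnionRDD (a ++ b is equivalent).
    val ab: RDD[Int] = a.union(b)

    // The new varargs SparkContext.union builds one flat UnionRDD over all
    // inputs rather than a chain of nested two-way unions.
    val abc: RDD[Int] = sc.union(a, b, c)

Because UnionRDD now accepts any Seq[RDD[T]], the varargs overload on SparkContext is a one-liner, and a union of n RDDs produces a single RDD whose splits are the concatenation of all the inputs' splits.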