diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 803c0638653c0aba4677e2f68e28883a6f72b053..20865d7d28ea91d557948d8cbe195a3f6b155922 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -82,11 +82,10 @@ abstract class RDD[T: ClassManifest](
     try { map(x => 1L).reduce(_+_) }
     catch { case e: UnsupportedOperationException => 0L }
 
-  def union(other: RDD[T]) = new UnionRDD(sc, this, other)
+  def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other))
   def cartesian[U: ClassManifest](other: RDD[U]) = new CartesianRDD(sc, this, other)
 
   def ++(other: RDD[T]) = this.union(other)
-
 }
 
 @serializable
@@ -268,36 +267,27 @@ private object CachedRDD {
 }
 
 @serializable
-abstract class UnionSplit[T: ClassManifest] extends Split {
-  def iterator(): Iterator[T]
-  def preferredLocations(): Seq[String]
-  def getId(): String
-}
-
-@serializable
-class UnionSplitImpl[T: ClassManifest](
-  rdd: RDD[T], split: Split)
-extends UnionSplit[T] {
-  override def iterator() = rdd.iterator(split)
-  override def preferredLocations() = rdd.preferredLocations(split)
-  override def getId() =
-    "UnionSplitImpl(" + split.getId() + ")"
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+  def iterator() = rdd.iterator(split)
+  def preferredLocations() = rdd.preferredLocations(split)
+  override def getId() = "UnionSplit(" + split.getId() + ")"
 }
 
 @serializable
-class UnionRDD[T: ClassManifest](
-  sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T])
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
 extends RDD[T](sc) {
-
-  @transient val splits_ : Array[UnionSplit[T]] = {
-    val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
-    val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
-    (a1 ++ a2).toArray
+  @transient val splits_ : Array[Split] = {
+    val splits: Seq[Split] =
+      for (rdd <- rdds; split <- rdd.splits)
+        yield new UnionSplit(rdd, split)
+    splits.toArray
   }
 
-  override def splits = splits_.asInstanceOf[Array[Split]]
+  override def splits = splits_
 
-  override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+  override def iterator(s: Split): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator()
 
   override def preferredLocations(s: Split): Seq[String] = 
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
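The hunks above collapse the UnionSplit/UnionSplitImpl pair into one concrete class and let UnionRDD take any number of parent RDDs instead of exactly two. A minimal sketch of how the reworked API reads from a driver program; the SparkContext construction and sample data are hypothetical, not part of this change:

    val sc = new SparkContext("local", "UnionExample")
    val a  = sc.parallelize(1 to 3)
    val b  = sc.parallelize(4 to 6)

    // ++ still delegates to union(other), which now wraps both parents
    // in a single UnionRDD over Array(this, other).
    val u = a ++ b

    // A UnionRDD's splits are the concatenation of its parents' splits.
    assert(u.splits.size == a.splits.size + b.splits.size)
    println(u.reduce(_ + _))  // 21
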
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 69c3332bb02d360bfaac7d66b5ed64601414b17f..953eac9eba109729a21ef071ffb1cec2353deb9d 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -33,13 +33,17 @@ extends Logging {
 
   // Methods for creating RDDs
 
-  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
+  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
     new ParallelArray[T](this, seq, numSlices)
 
-  def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
+  def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
     parallelize(seq, scheduler.numCores)
 
-  def textFile(path: String) = new HdfsTextFile(this, path)
+  def textFile(path: String): RDD[String] =
+    new HdfsTextFile(this, path)
+
+  def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+    new UnionRDD(this, rdds)
 
   // Methods for creating shared variables
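The new variadic SparkContext.union pairs with the Seq-based UnionRDD constructor, so any number of RDDs can be flattened in one step rather than by chaining binary unions. A sketch under the same assumptions as above:

    val parts = Seq(sc.parallelize(1 to 2),
                    sc.parallelize(3 to 4),
                    sc.parallelize(5 to 6))

    // Expand the Seq into the RDD[T]* vararg; one UnionRDD spans all inputs.
    val all = sc.union(parts: _*)
    println(all.reduce(_ + _))  // 21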