From 924f47dd11dd9b44211372bc7d7960066e26f682 Mon Sep 17 00:00:00 2001
From: Stephen Haberman <stephen@exigencecorp.com>
Date: Sat, 16 Feb 2013 13:38:42 -0600
Subject: [PATCH] Add RDD.subtract.

Instead of reusing the cogroup primitive, this adds a SubtractedRDD
that knows it only needs to keep rdd1's values (per split) in memory.
---
 core/src/main/scala/spark/RDD.scala           |  21 ++++
 .../main/scala/spark/rdd/SubtractedRDD.scala  | 108 ++++++++++++++++++
 core/src/test/scala/spark/ShuffleSuite.scala  |  26 +++++
 3 files changed, 155 insertions(+)
 create mode 100644 core/src/main/scala/spark/rdd/SubtractedRDD.scala

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index f6e927a989..a4c51a0115 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -30,6 +30,7 @@ import spark.rdd.MapPartitionsRDD
 import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
+import spark.rdd.SubtractedRDD
 import spark.rdd.UnionRDD
 import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
@@ -383,6 +384,26 @@ abstract class RDD[T: ClassManifest](
     filter(f.isDefinedAt).map(f)
   }
 
+  /**
+   * Return an RDD with the elements from `this` that are not in `other`.
+   * 
+   * Uses `this` partitioner/split size, because even if `other` is huge, the resulting
+   * RDD will be <= us.
+   */
+  def subtract(other: RDD[T]): RDD[T] =
+    subtract(other, partitioner.getOrElse(new HashPartitioner(splits.size)))
+
+  /**
+   * Return an RDD with the elements from `this` that are not in `other`.
+   */
+  def subtract(other: RDD[T], numSplits: Int): RDD[T] =
+    subtract(other, new HashPartitioner(numSplits))
+
+  /**
+   * Return an RDD with the elements from `this` that are not in `other`.
+   */
+  def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+
   /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
    */
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
new file mode 100644
index 0000000000..244874e4e0
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -0,0 +1,108 @@
+package spark.rdd
+
+import java.util.{HashSet => JHashSet}
+import scala.collection.JavaConversions._
+import spark.RDD
+import spark.Partitioner
+import spark.Dependency
+import spark.TaskContext
+import spark.Split
+import spark.SparkEnv
+import spark.ShuffleDependency
+import spark.OneToOneDependency
+
+/**
+ * An optimized version of cogroup for set difference/subtraction.
+ *
+ * It is possible to implement this operation with just `cogroup`, but
+ * that is less efficient because all of the entries from `rdd2`, for
+ * both matching and non-matching values in `rdd1`, are kept in the
+ * JHashMap until the end.
+ *
+ * With this implementation, only the entries from `rdd1` are kept in-memory,
+ * and the entries from `rdd2` are essentially streamed, as we only need to
+ * touch each once to decide if the value needs to be removed.
+ *
+ * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as
+ * you can use `rdd1`'s partitioner/split size and not worry about running
+ * out of memory because of the size of `rdd2`.
+ */
+private[spark] class SubtractedRDD[T: ClassManifest](
+    @transient var rdd1: RDD[T],
+    @transient var rdd2: RDD[T],
+    part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+
+  override def getDependencies: Seq[Dependency[_]] = {
+    Seq(rdd1, rdd2).map { rdd =>
+      if (rdd.partitioner == Some(part)) {
+        logInfo("Adding one-to-one dependency with " + rdd)
+        new OneToOneDependency(rdd)
+      } else {
+        logInfo("Adding shuffle dependency with " + rdd)
+        val mapSideCombinedRDD = rdd.mapPartitions(i => {
+          val set = new JHashSet[T]()
+          while (i.hasNext) {
+            set.add(i.next)
+          }
+          set.iterator
+        }, true)
+        // ShuffleDependency requires a tuple (k, v), which it will partition by k.
+        // We need this to partition to map to the same place as the k for
+        // OneToOneDependency, which means:
+        // - for already-tupled RDD[(A, B)], into getPartition(a)
+        // - for non-tupled RDD[C], into getPartition(c)
+        val part2 = new Partitioner() {
+          def numPartitions = part.numPartitions
+          def getPartition(key: Any) = key match {
+            case (k, v) => part.getPartition(k)
+            case k => part.getPartition(k)
+          }
+        }
+        new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
+      }
+    }
+  }
+
+  override def getSplits: Array[Split] = {
+    val array = new Array[Split](part.numPartitions)
+    for (i <- 0 until array.size) {
+      // Each CoGroupSplit will dependend on rdd1 and rdd2
+      array(i) = new CoGroupSplit(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
+        dependencies(j) match {
+          case s: ShuffleDependency[_, _] =>
+            new ShuffleCoGroupSplitDep(s.shuffleId)
+          case _ =>
+            new NarrowCoGroupSplitDep(rdd, i, rdd.splits(i))
+        }
+      }.toList)
+    }
+    array
+  }
+
+  override val partitioner = Some(part)
+
+  override def compute(s: Split, context: TaskContext): Iterator[T] = {
+    val split = s.asInstanceOf[CoGroupSplit]
+    val set = new JHashSet[T]
+    def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+        for (k <- rdd.iterator(itsSplit, context))
+          op(k.asInstanceOf[T])
+      case ShuffleCoGroupSplitDep(shuffleId) =>
+        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, split.index))
+          op(k.asInstanceOf[T])
+    }
+    // the first dep is rdd1; add all keys to the set
+    integrate(split.deps(0), set.add)
+    // the second dep is rdd2; remove all of its keys from the set
+    integrate(split.deps(1), set.remove)
+    set.iterator
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdd1 = null
+    rdd2 = null
+  }
+
+}
\ No newline at end of file
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..367083eab3 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -211,6 +211,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(rdd.keys.collect().toList === List(1, 2))
     assert(rdd.values.collect().toList === List("a", "b"))
   }
+
+  test("subtract") {
+    sc = new SparkContext("local", "test")
+    val a = sc.parallelize(Array(1, 2, 3), 2)
+    val b = sc.parallelize(Array(2, 3, 4), 4)
+    val c = a.subtract(b)
+    assert(c.collect().toSet === Set(1))
+    assert(c.splits.size === a.splits.size)
+  }
+
+  test("subtract with narrow dependency") {
+    sc = new SparkContext("local", "test")
+    // use a deterministic partitioner
+    val p = new Partitioner() {
+      def numPartitions = 5
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    // partitionBy so we have a narrow dependency
+    val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+    println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList)
+    // more splits/no partitioner so a shuffle dependency 
+    val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+    val c = a.subtract(b)
+    assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+    assert(c.partitioner.get === p)
+  }
 }
 
 object ShuffleSuite {
-- 
GitLab