Commit 63fe2255 authored by Stephen Haberman

Simplify SubtractedRDD in preparation for subtractByKey.

parent cbf8f0d4
@@ -639,6 +639,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
}, true)
}
+  // def subtractByKey(other: RDD[K]): RDD[(K,V)] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
}
private[spark]
......
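The commented-out subtractByKey above is the follow-up this commit prepares for: drop every pair whose key appears in `other`, regardless of its value. A hedged illustration of those intended semantics (not implemented in this commit; `sc` stands in for a SparkContext):

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
    val toRemove = sc.parallelize(Seq("b"))
    // pairs.subtractByKey(toRemove) would yield ("a", 1) and ("c", 3):
    // key "b" is dropped along with whatever value sits under it.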
@@ -408,8 +408,24 @@ abstract class RDD[T: ClassManifest](
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
-  def subtract(other: RDD[T]): RDD[T] =
-    subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
+  def subtract(other: RDD[T]): RDD[T] = {
+    // If we do have a partitioner, our T is really (K, V), and we'll need to
+    // unwrap the (T, null) that subtract does to get back to the K
+    val rdd = subtract(other, partitioner match {
+      case None => new HashPartitioner(partitions.size)
+      case Some(p) => new Partitioner() {
+        override def numPartitions = p.numPartitions
+        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+      }
+    })
+    // Hacky, but if we did have a partitioner, we can keep using it
+    new RDD[T](rdd) {
+      override def getPartitions = rdd.partitions
+      override def getDependencies = rdd.dependencies
+      override def compute(split: Partition, context: TaskContext) = rdd.compute(split, context)
+      override val partitioner = RDD.this.partitioner
+    }
+  }
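A hedged usage sketch of why this wrapper matters (`sc` again stands in for a SparkContext): because the anonymous RDD re-attaches `this` partitioner, downstream key-based operations can reuse the existing layout instead of shuffling again.

    val left  = sc.parallelize(Seq(("a", 1), ("b", 2))).partitionBy(new HashPartitioner(4))
    val right = sc.parallelize(Seq(("b", 2)))
    val diff  = left.subtract(right)  // removes the exact pair ("b", 2)
    diff.partitioner                  // Some(HashPartitioner(4)), inherited from left

Note that subtract compares whole elements: a ("b", 99) in `right` would not remove ("b", 2); matching on keys alone is exactly what the planned subtractByKey adds.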
/**
* Return an RDD with the elements from `this` that are not in `other`.
@@ -420,7 +436,9 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
-  def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+  def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+    new SubtractedRDD[T, Any](this.map((_, null)), other.map((_, null)), p).keys
+  }
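The tupling above reduces element subtraction to key subtraction: every element becomes a key paired with a null value, SubtractedRDD removes matching keys, and `.keys` projects the elements back out. A minimal local-collections sketch of the same idea:

    val thisElems  = Seq(1, 2, 3).map(x => (x, null))
    val otherElems = Seq(3).map(x => (x, null))
    val kept = thisElems.filterNot { case (k, _) => otherElems.exists(_._1 == k) }
    kept.map(_._1)  // Seq(1, 2)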
/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
......
package spark.rdd
-import java.util.{HashSet => JHashSet}
+import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
import spark.RDD
import spark.Partitioner
import spark.Dependency
@@ -27,39 +28,20 @@ import spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[T: ClassManifest](
-    @transient var rdd1: RDD[T],
-    @transient var rdd2: RDD[T],
-    part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
+    @transient var rdd1: RDD[(K, V)],
+    @transient var rdd2: RDD[(K, V)],
+    part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {

  override def getDependencies: Seq[Dependency[_]] = {
    Seq(rdd1, rdd2).map { rdd =>
-      if (rdd.partitioner == Some(part)) {
-        logInfo("Adding one-to-one dependency with " + rdd)
-        new OneToOneDependency(rdd)
-      } else {
-        logInfo("Adding shuffle dependency with " + rdd)
-        val mapSideCombinedRDD = rdd.mapPartitions(i => {
-          val set = new JHashSet[T]()
-          while (i.hasNext) {
-            set.add(i.next)
-          }
-          set.iterator
-        }, true)
-        // ShuffleDependency requires a tuple (k, v), which it will partition by k.
-        // We need this to partition to map to the same place as the k for
-        // OneToOneDependency, which means:
-        // - for already-tupled RDD[(A, B)], into getPartition(a)
-        // - for non-tupled RDD[C], into getPartition(c)
-        val part2 = new Partitioner() {
-          def numPartitions = part.numPartitions
-          def getPartition(key: Any) = key match {
-            case (k, v) => part.getPartition(k)
-            case k => part.getPartition(k)
-          }
-        }
-        new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
-      }
+      if (rdd.partitioner == Some(part)) {
+        logInfo("Adding one-to-one dependency with " + rdd)
+        new OneToOneDependency(rdd)
+      } else {
+        logInfo("Adding shuffle dependency with " + rdd)
+        new ShuffleDependency(rdd, part)
+      }
}
}
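A hedged sketch of when each branch fires (`sc` stands in for a SparkContext; the RDDs are illustrative):

    val p = new HashPartitioner(4)
    val prePartitioned = sc.parallelize(Seq((1, "a"))).partitionBy(p)
    // prePartitioned.partitioner == Some(p): its partitions already line up,
    // so SubtractedRDD adds a OneToOneDependency and no data moves.
    val unpartitioned = sc.parallelize(Seq((1, "b")))
    // unpartitioned.partitioner == None: a ShuffleDependency repartitions it
    // by p so matching keys land in the same SubtractedRDD partition.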
@@ -81,22 +63,32 @@ private[spark] class SubtractedRDD[T: ClassManifest](
override val partitioner = Some(part)
-  override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+  override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
    val partition = p.asInstanceOf[CoGroupPartition]
-    val set = new JHashSet[T]
-    def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+    val map = new JHashMap[K, ArrayBuffer[V]]
+    def getSeq(k: K): ArrayBuffer[V] = {
+      val seq = map.get(k)
+      if (seq != null) {
+        seq
+      } else {
+        val seq = new ArrayBuffer[V]()
+        map.put(k, seq)
+        seq
+      }
+    }
+    def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
-        for (k <- rdd.iterator(itsSplit, context))
-          op(k.asInstanceOf[T])
+        for (t <- rdd.iterator(itsSplit, context))
+          op(t.asInstanceOf[(K, V)])
      case ShuffleCoGroupSplitDep(shuffleId) =>
-        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
-          op(k.asInstanceOf[T])
+        for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+          op(t.asInstanceOf[(K, V)])
    }
-    // the first dep is rdd1; add all keys to the set
-    integrate(partition.deps(0), set.add)
-    // the second dep is rdd2; remove all of its keys from the set
-    integrate(partition.deps(1), set.remove)
-    set.iterator
+    // the first dep is rdd1; add all values to the map
+    integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+    // the second dep is rdd2; remove all of its keys
+    integrate(partition.deps(1), t => map.remove(t._1))
+    map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
}
override def clearDependencies() {
......
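The per-partition algorithm in compute, restated as a hedged sketch over plain Scala collections (standard library only; the concrete types are illustrative):

    import scala.collection.mutable.{ArrayBuffer, HashMap}
    val map = HashMap.empty[String, ArrayBuffer[Int]]
    // dep 0 (rdd1): collect every value under its key
    Seq(("a", 1), ("a", 2), ("b", 3)).foreach { case (k, v) =>
      map.getOrElseUpdate(k, ArrayBuffer.empty) += v
    }
    // dep 1 (rdd2): a key seen on the other side is removed wholesale
    Seq(("a", 99)).foreach { case (k, _) => map.remove(k) }
    // flatten back to pairs: here, just ("b", 3)
    val result = map.iterator.flatMap { case (k, vs) => vs.iterator.map((k, _)) }.toList

Keeping all values per key, rather than a bare set of elements, is what lets the upcoming subtractByKey return the surviving (K, V) pairs intact.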