Commit 00a11304 authored by Reynold Xin

Added mapSideCombine flag to CoGroupedRDD. Added unit test for CoGroupedRDD.
parent c1e9cdc4
@@ -2,10 +2,11 @@ package spark.rdd
 import java.io.{ObjectOutputStream, IOException}
 import java.util.{HashMap => JHashMap}
 import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer
-import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext}
+import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
 import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -28,7 +29,8 @@ private[spark] case class NarrowCoGroupSplitDep(
 private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
 private[spark]
-class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable {
+class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep])
+  extends Partition with Serializable {
   override val index: Int = idx
   override def hashCode(): Int = idx
 }
@@ -40,7 +42,19 @@ private[spark] class CoGroupAggregator
     { (b1, b2) => b1 ++ b2 })
   with Serializable
-class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
+/**
+ * An RDD that cogroups its parents. For each key k in the parent RDDs, the resulting RDD
+ * contains a tuple with the list of values for that key.
+ *
+ * @param rdds parent RDDs.
+ * @param part partitioner used to partition the shuffle output.
+ * @param mapSideCombine flag indicating whether to merge values before the shuffle step.
+ */
+class CoGroupedRDD[K](
+    @transient var rdds: Seq[RDD[(K, _)]],
+    part: Partitioner,
+    val mapSideCombine: Boolean = true)
   extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
   private val aggr = new CoGroupAggregator
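To make the new parameter concrete, here is a small usage sketch (not part of the patch; the input data, variable names, and an existing SparkContext `sc` are assumed for illustration). Cogrouping two parents yields, for each key, one Seq of values per parent; setting mapSideCombine to false only changes how the values travel through the shuffle, not the final result.

// Illustrative only: assumes an existing SparkContext `sc`.
import spark.{HashPartitioner, RDD}
import spark.rdd.CoGroupedRDD

val left  = sc.makeRDD(Seq((1, "a"), (1, "b"), (2, "c")), 2).asInstanceOf[RDD[(Int, Any)]]
val right = sc.makeRDD(Seq((1, "x"), (3, "y")), 2).asInstanceOf[RDD[(Int, Any)]]

// mapSideCombine = false: skip the pre-shuffle merge and ship raw (key, value) pairs.
val grouped = new CoGroupedRDD[Int](Seq(left, right), new HashPartitioner(2), false)

// Resulting elements (one Seq of values per parent; key order not guaranteed):
//   (1, Seq(Seq("a", "b"), Seq("x")))
//   (2, Seq(Seq("c"),      Seq()))
//   (3, Seq(Seq(),         Seq("y")))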
@@ -52,8 +66,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
         new OneToOneDependency(rdd)
       } else {
         logInfo("Adding shuffle dependency with " + rdd)
-        val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
-        new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+        if (mapSideCombine) {
+          val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
+          new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+        } else {
+          new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part)
+        }
       }
     }
   }
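The branch above decides what flows into the shuffle. With mapSideCombine on, each map partition first folds its records into one (key, buffer) pair per key, so the ShuffleDependency carries ArrayBuffer[Any] values; with it off, the raw pairs are shuffled as-is. A minimal standalone sketch of that per-partition pre-merge (a simplification, not Spark's Aggregator code):

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Collapse an iterator of (key, value) records into one (key, buffer) pair per key.
def combineValuesByKey[K, V](records: Iterator[(K, V)]): Iterator[(K, ArrayBuffer[V])] = {
  val combined = new HashMap[K, ArrayBuffer[V]]
  for ((k, v) <- records) {
    combined.getOrElseUpdate(k, new ArrayBuffer[V]) += v
  }
  combined.iterator
}

Shuffling one buffer per key per partition instead of one record per value is the payoff of the flag; the cost is buffering those values on the map side first.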
@@ -82,6 +100,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
     val numRdds = split.deps.size
     // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
     val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+
     def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
       val seq = map.get(k)
       if (seq != null) {
@@ -92,6 +111,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
         seq
       }
     }
+
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
         // Read them from the parent
@@ -102,9 +122,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
       case ShuffleCoGroupSplitDep(shuffleId) => {
         // Read map outputs of shuffle
         val fetcher = SparkEnv.get.shuffleFetcher
-        val fetchItr = fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics)
-        for ((k, vs) <- fetchItr) {
-          getSeq(k)(depNum) ++= vs
+        if (mapSideCombine) {
+          // With map side combine on, for each key, the shuffle fetcher returns a list of values.
+          fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics).foreach {
+            case (key, values) => getSeq(key)(depNum) ++= values
+          }
+        } else {
+          // With map side combine off, for each key the shuffle fetcher returns a single value.
+          fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics).foreach {
+            case (key, value) => getSeq(key)(depNum) += value
+          }
         }
       }
     }
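On the reduce side, both branches feed the same structure: for every key, `map` holds one ArrayBuffer per parent RDD, and each fetched record is appended to the buffer at the index of the dependency it came from. A rough standalone model of that bookkeeping (simplified; not the actual compute() code):

import java.util.{HashMap => JHashMap}
import scala.collection.mutable.ArrayBuffer

// One buffer per co-grouped parent; `depNum` identifies which parent a record came from.
def coGroup[K](numRdds: Int, fetched: Seq[(Int, K, Any)]): JHashMap[K, Seq[ArrayBuffer[Any]]] = {
  val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
  def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
    val existing = map.get(k)
    if (existing != null) {
      existing
    } else {
      val fresh = Seq.fill(numRdds)(new ArrayBuffer[Any])
      map.put(k, fresh)
      fresh
    }
  }
  for ((depNum, key, value) <- fetched) {
    getSeq(key)(depNum) += value
  }
  map
}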
@@ -3,7 +3,7 @@ package spark
 import scala.collection.mutable.HashMap
 import org.scalatest.FunSuite
 import spark.SparkContext._
-import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
+import spark.rdd.{CoalescedRDD, CoGroupedRDD, PartitionPruningRDD}
 class RDDSuite extends FunSuite with LocalSparkContext {
@@ -123,6 +123,36 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(rdd.collect().toList === List(1, 2, 3, 4))
   }
 
+  test("cogrouped RDDs") {
+    sc = new SparkContext("local", "test")
+    val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
+    val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
+
+    // Use cogroup function
+    val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
+    assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
+    assert(cogrouped(2) === (Seq("two"), Seq("two1")))
+    assert(cogrouped(3) === (Seq("three"), Seq()))
+
+    // Construct CoGroupedRDD directly, with map side combine enabled
+    val cogrouped1 = new CoGroupedRDD[Int](
+      Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
+      new HashPartitioner(3),
+      true).collectAsMap()
+    assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
+    assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
+    assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
+
+    // Construct CoGroupedRDD directly, with map side combine disabled
+    val cogrouped2 = new CoGroupedRDD[Int](
+      Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
+      new HashPartitioner(3),
+      false).collectAsMap()
+    assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
+    assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
+    assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
+  }
+
   test("coalesced RDDs") {
     sc = new SparkContext("local", "test")
     val data = sc.parallelize(1 to 10, 10)
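The first block of the new test goes through the high-level cogroup operator, which can be understood as a thin layer over CoGroupedRDD. A rough sketch of that relationship (not the verbatim PairRDDFunctions code; it reuses rdd1 and rdd2 from the test above and fixes the types for simplicity):

// Sketch only: a two-way cogroup expressed in terms of CoGroupedRDD.
val cg = new CoGroupedRDD[Int](
  Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
  new HashPartitioner(3))
val pairs: RDD[(Int, (Seq[String], Seq[String]))] = cg.map { case (k, Seq(vs, ws)) =>
  (k, (vs.asInstanceOf[Seq[String]], ws.asInstanceOf[Seq[String]]))
}
// pairs.collectAsMap() should then agree with rdd1.cogroup(rdd2).collectAsMap() above.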