Skip to content
Snippets Groups Projects
Commit 43288732 authored by Stephen Haberman's avatar Stephen Haberman
Browse files

Add assertion about dependencies.

parent c34b8ad2
No related branches found
No related tags found
No related merge requests found
......@@ -62,7 +62,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (Option(partitioner) == self.partitioner) {
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
......
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
......@@ -105,11 +106,20 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
def numPartitions = 2
def getPartition(key: Any) = key.asInstanceOf[Int]
}
val pairs = rddToPairRDDFunctions(sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1)))).partitionBy(p)
val sums = pairs.reduceByKey(p, _+_)
println(sums.toDebugString)
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
val sums = pairs.reduceByKey(_+_)
assert(sums.collect().toSet === Set((1, 4), (0, 1)))
assert(sums.partitioner === Some(p))
// count the dependencies to make sure there is only 1 ShuffledRDD
val deps = new HashSet[RDD[_]]()
def visit(r: RDD[_]) {
for (dep <- r.dependencies) {
deps += dep.rdd
visit(dep.rdd)
}
}
visit(sums)
assert(deps.size === 2) // ShuffledRDD, ParallelCollection
}
test("join") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment