diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index c9411f4208d15cb941d81a94b47cae1db8928135..3df2398b6c5f0cecff5cf042810300b8a430d194 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -87,9 +87,12 @@ private trait DAGScheduler extends Scheduler with Logging {
   }
 
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
-    // Kind of ugly: need to register RDDs with the cache here since
-    // we can't do it in its constructor because # of splits is unknown
+    // Kind of ugly: need to register RDDs with the cache and map output tracker here
+    // since we can't do it in the RDD constructor because # of splits is unknown
     cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+    if (shuffleDep != None) {
+      mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+    }
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 30649367586b22ec51b30c567e84be8b0f1e8e10..a183bf80faf457d728bef95f689bce001412cdbe 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -51,13 +51,16 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
     val port = System.getProperty("spark.master.port").toInt
     trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
   }
+
+  def registerShuffle(shuffleId: Int, numMaps: Int) {
+    if (serverUris.get(shuffleId) != null) {
+      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+    }
+    serverUris.put(shuffleId, new Array[String](numMaps))
+  }
 
-  def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
+  def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
     var array = serverUris.get(shuffleId)
-    if (array == null) {
-      array = Array.fill[String](numMaps)(null)
-      serverUris.put(shuffleId, array)
-    }
     array.synchronized {
       array(mapId) = serverUri
     }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index d14d313b4876f520d48913ce2e99d5949a3f9d5f..c61cb90f826678a3c5ae070ef3a7a48ec514ee39 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -5,9 +5,13 @@ import org.scalatest.prop.Checkers
 import org.scalacheck.Arbitrary._
 import org.scalacheck.Gen
 import org.scalacheck.Prop._
-import SparkContext._
+
+import com.google.common.io.Files
+
 import scala.collection.mutable.ArrayBuffer
 
+import SparkContext._
+
 class ShuffleSuite extends FunSuite {
   test("groupByKey") {
     val sc = new SparkContext("local", "test")
@@ -186,4 +190,14 @@ class ShuffleSuite extends FunSuite {
     sc.stop()
   }
 
+  test("zero-partition RDD") {
+    val sc = new SparkContext("local", "test")
+    val emptyDir = Files.createTempDir()
+    val file = sc.textFile(emptyDir.getAbsolutePath)
+    assert(file.splits.size == 0)
+    assert(file.collect().toList === Nil)
+    // Test that a shuffle on the file works, because this used to be a bug
+    assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+    sc.stop()
+  }
 }
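
Note on the fix above: before this patch, registerMapOutput() created the serverUris entry lazily on the first map-output registration, so a shuffle with zero map tasks never got an entry and the reduce side failed on a null lookup (the "zero-partition RDD" case the new test covers). The sketch below is not part of the patch; it is a minimal stand-alone Scala illustration of the new contract, with a hypothetical TrackerSketch object and a plain ConcurrentHashMap standing in for Spark's actual MapOutputTracker internals:

import java.util.concurrent.ConcurrentHashMap

// Simplified stand-in for MapOutputTracker (illustration only, not Spark code).
// The scheduler registers every shuffle eagerly in newStage(), so even a
// shuffle with zero map tasks has an entry by the time reducers look it up.
object TrackerSketch {
  private val serverUris = new ConcurrentHashMap[Int, Array[String]]

  def registerShuffle(shuffleId: Int, numMaps: Int) {
    if (serverUris.get(shuffleId) != null) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
    serverUris.put(shuffleId, new Array[String](numMaps))
  }

  def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
    // No lazy creation anymore: the shuffle must have been pre-registered.
    val array = serverUris.get(shuffleId)
    array.synchronized {
      array(mapId) = serverUri
    }
  }

  def main(args: Array[String]) {
    registerShuffle(0, 0)                 // shuffle over a zero-partition RDD
    assert(serverUris.get(0).length == 0) // reducers see an empty array, not null
  }
}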