Commit e269f6f7 authored by Matei Zaharia

Register RDDs with the MapOutputTracker even if they have no partitions.

Fixes #105.
parent 5fd101d7
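
Why this was needed: before this change, the MapOutputTracker only learned about a shuffle lazily, when its first map output was registered. A shuffle over an RDD with zero partitions never registers any map outputs, so the tracker had no entry for it, and the reduce side failed when it asked for the map locations. The sketch below models that failure mode with a toy tracker; the names are illustrative, not Spark's actual API.

// Toy model of the pre-fix behavior (hypothetical names, not Spark's classes).
import java.util.concurrent.ConcurrentHashMap

object ZeroPartitionSketch {
  private val serverUris = new ConcurrentHashMap[Int, Array[String]]

  // Pre-fix: the URI array for a shuffle was created lazily, on the first map output.
  def registerMapOutputLazily(shuffleId: Int, numMaps: Int, mapId: Int, uri: String) {
    var array = serverUris.get(shuffleId)
    if (array == null) {
      array = Array.fill[String](numMaps)(null)
      serverUris.put(shuffleId, array)
    }
    array(mapId) = uri
  }

  // What the reduce side effectively did: look up the shuffle, fail if unknown.
  def getServerUris(shuffleId: Int): Array[String] = {
    val array = serverUris.get(shuffleId)
    if (array == null) {
      throw new IllegalStateException("Unknown shuffle ID " + shuffleId)
    }
    array
  }

  def main(args: Array[String]) {
    // Zero map partitions: registerMapOutputLazily is never called for this
    // shuffle, so the lookup throws even though the shuffle is "complete".
    getServerUris(0)
  }
}

With the fix, newStage registers every shuffle eagerly, so a zero-partition shuffle simply yields an empty (zero-length) array of server URIs.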
DAGScheduler.scala:
@@ -87,9 +87,12 @@ private trait DAGScheduler extends Scheduler with Logging {
   }
 
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
-    // Kind of ugly: need to register RDDs with the cache here since
-    // we can't do it in its constructor because # of splits is unknown
+    // Kind of ugly: need to register RDDs with the cache and map output tracker here
+    // since we can't do it in the RDD constructor because # of splits is unknown
     cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+    if (shuffleDep != None) {
+      mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+    }
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
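
Note that registerShuffle is passed rdd.splits.size even when that size is zero, which is the entire point of the fix: the tracker ends up with a (possibly empty) entry for every shuffle rather than no entry at all. As an aside, the `!= None` guard plus `.get` above could equally be written with Option's foreach; a sketch, not part of the commit:

// Equivalent to the guarded .get in the diff above.
shuffleDep.foreach { dep =>
  mapOutputTracker.registerShuffle(dep.shuffleId, rdd.splits.size)
}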
MapOutputTracker.scala:
@@ -51,13 +51,16 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
     val port = System.getProperty("spark.master.port").toInt
     trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
   }
 
+  def registerShuffle(shuffleId: Int, numMaps: Int) {
+    if (serverUris.get(shuffleId) != null) {
+      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+    }
+    serverUris.put(shuffleId, new Array[String](numMaps))
+  }
+
-  def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
+  def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
     var array = serverUris.get(shuffleId)
-    if (array == null) {
-      array = Array.fill[String](numMaps)(null)
-      serverUris.put(shuffleId, array)
-    }
     array.synchronized {
       array(mapId) = serverUri
     }
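
The new contract is that registerShuffle must run before any registerMapOutput call for that shuffle (registerMapOutput no longer creates the array, so it would hit a null on an unregistered shuffle), and that registering the same shuffle twice is an error. A self-contained sketch of that contract, mirroring the diff with toy names rather than Spark's real classes:

// Self-contained model of the post-fix contract (toy object, illustrative only).
import java.util.concurrent.ConcurrentHashMap

object RegistrationOrderSketch {
  private val serverUris = new ConcurrentHashMap[Int, Array[String]]

  def registerShuffle(shuffleId: Int, numMaps: Int) {
    if (serverUris.get(shuffleId) != null) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
    serverUris.put(shuffleId, new Array[String](numMaps))
  }

  def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
    // Assumes registerShuffle already ran; a missing entry would now be a bug.
    val array = serverUris.get(shuffleId)
    array.synchronized {
      array(mapId) = serverUri
    }
  }

  def main(args: Array[String]) {
    registerShuffle(0, 0)                    // zero-partition shuffle: known but empty
    assert(serverUris.get(0).length == 0)
    registerShuffle(1, 2)
    registerMapOutput(1, 0, "http://host-a") // fills slot 0 of shuffle 1
    registerMapOutput(1, 1, "http://host-b") // fills slot 1 of shuffle 1
    assert(serverUris.get(1).forall(_ != null))
  }
}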
ShuffleSuite.scala:
@@ -5,9 +5,13 @@ import org.scalatest.prop.Checkers
 import org.scalacheck.Arbitrary._
 import org.scalacheck.Gen
 import org.scalacheck.Prop._
-import SparkContext._
+
+import com.google.common.io.Files
+import scala.collection.mutable.ArrayBuffer
+
+import SparkContext._
 
 class ShuffleSuite extends FunSuite {
   test("groupByKey") {
     val sc = new SparkContext("local", "test")
@@ -186,4 +190,14 @@ class ShuffleSuite extends FunSuite {
     sc.stop()
   }
+
+  test("zero-partition RDD") {
+    val sc = new SparkContext("local", "test")
+    val emptyDir = Files.createTempDir()
+    val file = sc.textFile(emptyDir.getAbsolutePath)
+    assert(file.splits.size == 0)
+    assert(file.collect().toList === Nil)
+    // Test that a shuffle on the file works, because this used to be a bug
+    assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+    sc.stop()
+  }
 }
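
The new test exercises the whole path end to end: reading an empty temporary directory produces an RDD with zero splits, collect() on it returns Nil, and the reduceByKey forces a shuffle whose map side has no tasks at all, which is exactly the case that previously left the shuffle unregistered in the MapOutputTracker.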