diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index c1495d531714a8c38b3309f36096b165dceb6b0c..84626df553a38de903b9968ead04e1e886a28c7d 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -403,17 +403,6 @@ private object Utils extends Logging {
     hostPortParseResults.get(hostPort)
   }
 
-  def addIfNoPort(hostPort: String, port: Int): String = {
-    if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port)
-
-    // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
-    // For now, we assume that if port exists, then it is valid - not check if it is an int > 0
-    val indx: Int = hostPort.lastIndexOf(':')
-    if (-1 != indx) return hostPort
-
-    hostPort + ":" + port
-  }
-
   private[spark] val daemonThreadFactory: ThreadFactory =
     new ThreadFactoryBuilder().setDaemon(true).build()
 
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 18d105e0a43c5f9920a633413669c0429398a686..f1c6266bacb49320709c2a7ebf814cefefd4ffd2 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -516,9 +516,16 @@ private[spark] class TaskSetManager(
       logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
         tid, info.duration, tasksFinished, numTasks))
       // Deserialize task result and pass it to the scheduler
-      val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
-      result.metrics.resultSize = serializedData.limit()
-      sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+      try {
+        val result = ser.deserialize[TaskResult[_]](serializedData)
+        result.metrics.resultSize = serializedData.limit()
+        sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+      } catch {
+        case cnf: ClassNotFoundException =>
+          val loader = Thread.currentThread().getContextClassLoader
+          throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
+        case ex => throw ex
+      }
       // Mark finished and stop if we've finished all the tasks
       finished(index) = true
       if (tasksFinished == numTasks) {
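The TaskSetManager hunk above stops passing `getClass.getClassLoader` to `ser.deserialize` and instead reports which context classloader was active when a task result fails to deserialize, which is usually the first question when chasing a `ClassNotFoundException`. Below is a minimal, self-contained sketch of that error-reporting pattern using plain Java serialization; the names `DeserializeWithContext` and `roundTrip` are hypothetical, not part of the patch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical helper mirroring the patch's pattern: on ClassNotFoundException,
// rethrow with the active context classloader named in the message.
object DeserializeWithContext {
  def roundTrip[T](value: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(value.asInstanceOf[AnyRef])
    out.close()
    try {
      val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
      in.readObject().asInstanceOf[T]
    } catch {
      case cnf: ClassNotFoundException =>
        val loader = Thread.currentThread().getContextClassLoader
        throw new RuntimeException("ClassNotFound with classloader: " + loader, cnf)
    }
  }

  def main(args: Array[String]) {
    println(roundTrip(List(1, 2, 3))) // List(1, 2, 3)
  }
}
```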
diff --git a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
index 95308c728273c1f7d03908ce1729b7fb94acb4cd..1d69d658f7d8a27aea617e883279a88552d019e1 100644
--- a/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/spark/storage/BlockFetcherIterator.scala
@@ -124,6 +124,7 @@ object BlockFetcherIterator {
     protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
+      val originalTotalBlocks = _totalBlocks
       val remoteRequests = new ArrayBuffer[FetchRequest]
       for ((address, blockInfos) <- blocksByAddress) {
         if (address == blockManagerId) {
@@ -140,8 +141,15 @@ object BlockFetcherIterator {
           var curBlocks = new ArrayBuffer[(String, Long)]
           while (iterator.hasNext) {
             val (blockId, size) = iterator.next()
-            curBlocks += ((blockId, size))
-            curRequestSize += size
+            // Skip empty blocks
+            if (size > 0) {
+              curBlocks += ((blockId, size))
+              curRequestSize += size
+            } else if (size == 0) {
+              _totalBlocks -= 1
+            } else {
+              throw new BlockException(blockId, "Negative block size " + size)
+            }
             if (curRequestSize >= minRequestSize) {
               // Add this FetchRequest
               remoteRequests += new FetchRequest(address, curBlocks)
@@ -155,6 +163,8 @@ object BlockFetcherIterator {
           }
         }
       }
+      logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
+        originalTotalBlocks + " blocks")
       remoteRequests
     }
 
@@ -278,53 +288,6 @@ object BlockFetcherIterator {
       logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
     }
 
-    override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
-      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
-      // at most maxBytesInFlight in order to limit the amount of data in flight.
-      val originalTotalBlocks = _totalBlocks;
-      val remoteRequests = new ArrayBuffer[FetchRequest]
-      for ((address, blockInfos) <- blocksByAddress) {
-        if (address == blockManagerId) {
-          localBlockIds ++= blockInfos.map(_._1)
-        } else {
-          remoteBlockIds ++= blockInfos.map(_._1)
-          // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
-          // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
-          // nodes, rather than blocking on reading output from one node.
-          val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
-          logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
-          val iterator = blockInfos.iterator
-          var curRequestSize = 0L
-          var curBlocks = new ArrayBuffer[(String, Long)]
-          while (iterator.hasNext) {
-            val (blockId, size) = iterator.next()
-            if (size > 0) {
-              curBlocks += ((blockId, size))
-              curRequestSize += size
-            } else if (size == 0) {
-              //here we changes the totalBlocks
-              _totalBlocks -= 1
-            } else {
-              throw new BlockException(blockId, "Negative block size " + size)
-            }
-            if (curRequestSize >= minRequestSize) {
-              // Add this FetchRequest
-              remoteRequests += new FetchRequest(address, curBlocks)
-              curRequestSize = 0
-              curBlocks = new ArrayBuffer[(String, Long)]
-            }
-          }
-          // Add in the final request
-          if (!curBlocks.isEmpty) {
-            remoteRequests += new FetchRequest(address, curBlocks)
-          }
-        }
-      }
-      logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
-        originalTotalBlocks + " blocks")
-      remoteRequests
-    }
-
     private var copiers: List[_ <: Thread] = null
 
     override def initialize() {
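With this change, `splitLocalRemoteBlocks` skips zero-sized remote blocks (decrementing `_totalBlocks` so the fetch accounting stays consistent), rejects negative sizes outright, and the near-duplicate override deleted in the last hunk is replaced by this single shared copy. The following standalone sketch models just the skip-and-group logic; it uses plain nested buffers instead of `FetchRequest` and ignores addresses, so it is a simplification rather than the patched class.

```scala
import scala.collection.mutable.ArrayBuffer

// Simplified model of the grouping loop: batch (blockId, size) pairs into
// requests of at least minRequestSize, skipping zero-sized blocks entirely.
object GroupBlocks {
  def group(blocks: Seq[(String, Long)], minRequestSize: Long): ArrayBuffer[ArrayBuffer[(String, Long)]] = {
    val requests = new ArrayBuffer[ArrayBuffer[(String, Long)]]
    var curBlocks = new ArrayBuffer[(String, Long)]
    var curRequestSize = 0L
    for ((blockId, size) <- blocks) {
      if (size > 0) {
        curBlocks += ((blockId, size))
        curRequestSize += size
      } else if (size < 0) {
        throw new IllegalArgumentException("Negative block size " + size)
      }
      // size == 0 falls through: no request is ever issued for an empty block
      if (curRequestSize >= minRequestSize) {
        requests += curBlocks
        curBlocks = new ArrayBuffer[(String, Long)]
        curRequestSize = 0
      }
    }
    if (!curBlocks.isEmpty) requests += curBlocks
    requests
  }

  def main(args: Array[String]) {
    val reqs = group(Seq(("a", 40L), ("b", 0L), ("c", 70L), ("d", 5L)), 100L)
    println(reqs) // ArrayBuffer(ArrayBuffer((a,40), (c,70)), ArrayBuffer((d,5)))
  }
}
```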
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index 9fb7e001badcbf5156a7755fa047b69764a8f586..cd79bd2bdad0ceefd7d8a89c846b77448bd11204 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -52,7 +52,7 @@ private[spark] object AkkaUtils {
     """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
       lifecycleEvents, akkaWriteTimeout))
 
-    val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
+    val actorSystem = ActorSystem(name, akkaConf)
 
     // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
     // hack because Akka doesn't let you figure out the port through the public API yet.
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 271f4a4e44177323886bb948ab54124c9c61b39a..b967016cf726791b543781a9f42cf8c9607aab71 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -341,6 +341,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(c.count === 10)
   }
 
+  test("zero sized blocks") {
+    // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+
+    // 10 partitions from 4 keys
+    val NUM_BLOCKS = 10
+    val a = sc.parallelize(1 to 4, NUM_BLOCKS)
+    val b = a.map(x => (x, x*2))
+
+    // NOTE: The default Java serializer doesn't create zero-sized blocks.
+    //       So, use Kryo
+    val c = new ShuffledRDD(b, new HashPartitioner(10), classOf[spark.KryoSerializer].getName)
+
+    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+    assert(c.count === 4)
+
+    val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
+      val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+      statuses.map(x => x._2)
+    }
+    val nonEmptyBlocks = blockSizes.filter(x => x > 0)
+
+    // We should have at most 4 non-zero sized partitions
+    assert(nonEmptyBlocks.size <= 4)
+  }
+
 }
 
 object ShuffleSuite {
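The new test works because hashing 4 distinct keys into 10 partitions leaves at least 6 reduce-side blocks empty, and with Kryo those blocks are written as zero bytes. A small sketch of that pigeonhole argument, using a plain `hashCode`-mod bucket choice as a simplified stand-in for Spark's `HashPartitioner` (which uses a non-negative mod):

```scala
// Why at most 4 of the 10 shuffle blocks are non-empty: 4 distinct keys can
// occupy at most 4 hash buckets, so at least 6 partitions receive nothing.
object EmptyPartitions {
  // Simplified stand-in for HashPartitioner's bucket choice
  def partition(key: Any, numPartitions: Int): Int =
    math.abs(key.hashCode) % numPartitions

  def main(args: Array[String]) {
    val numPartitions = 10
    val keys = 1 to 4
    val occupied = keys.map(k => partition(k, numPartitions)).toSet
    println("occupied partitions: " + occupied)                     // at most 4
    println("empty partitions: " + (numPartitions - occupied.size)) // at least 6
  }
}
```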