Skip to content
Snippets Groups Projects
Commit 3b0cd173 authored by Shivaram Venkataraman
Browse files

Merge branch 'master' of git://github.com/mesos/spark

Conflicts:
	core/src/test/scala/spark/ShuffleSuite.scala
parents 19fd6d54 8cb81782
No related branches found
No related tags found
No related merge requests found
......@@ -403,17 +403,6 @@ private object Utils extends Logging {
hostPortParseResults.get(hostPort)
}
/**
 * Returns `hostPort` unchanged when it already carries a port, otherwise appends ":" + port.
 *
 * A port is considered present when the string contains a ':' - except for a bracketed
 * IPv6 literal without a port suffix (e.g. "[::1]"), whose colons all belong to the address.
 * An existing port is assumed valid; it is not checked to be an int > 0.
 * Un-bracketed IPv6 literals are still ambiguous and are returned unchanged.
 *
 * @param hostPort host name or address, optionally followed by ":port"
 * @param port default port to append when none is present; must be > 0
 * @return `hostPort` itself if it already has a port, else `hostPort` followed by ":" and `port`
 * @throws IllegalArgumentException if `port` is not positive
 */
def addIfNoPort(hostPort: String, port: Int): String = {
  if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port)
  // "[addr]" ends with ']' while "[addr]:port" does not, so this detects a bracketed
  // address whose only colons are inside the brackets.
  val bracketedWithoutPort = hostPort.startsWith("[") && hostPort.endsWith("]")
  val hasPort = !bracketedWithoutPort && hostPort.lastIndexOf(':') != -1
  if (hasPort) hostPort else hostPort + ":" + port
}
private[spark] val daemonThreadFactory: ThreadFactory =
new ThreadFactoryBuilder().setDaemon(true).build()
......
......@@ -516,9 +516,16 @@ private[spark] class TaskSetManager(
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
try {
val result = ser.deserialize[TaskResult[_]](serializedData)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
} catch {
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread().getContextClassLoader
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
case ex => throw ex
}
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
......
......@@ -124,6 +124,7 @@ object BlockFetcherIterator {
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
......@@ -140,8 +141,15 @@ object BlockFetcherIterator {
var curBlocks = new ArrayBuffer[(String, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
curBlocks += ((blockId, size))
curRequestSize += size
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
curRequestSize += size
} else if (size == 0) {
_totalBlocks -= 1
} else {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= minRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
......@@ -155,6 +163,8 @@ object BlockFetcherIterator {
}
}
}
logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
originalTotalBlocks + " blocks")
remoteRequests
}
......@@ -278,53 +288,6 @@ object BlockFetcherIterator {
logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
}
/**
 * Sorts the blocks of `blocksByAddress` into local and remote ones and batches the remote
 * ones into fetch requests.
 *
 * Blocks whose address equals `blockManagerId` are recorded in `localBlockIds`; all other
 * block ids are recorded in `remoteBlockIds`. For each remote address, non-empty blocks are
 * accumulated into a request that is closed as soon as its total size reaches
 * max(maxBytesInFlight / 5, 1); any leftover blocks form a final request for that address.
 * Zero-sized blocks are not requested and each one decrements `_totalBlocks`.
 *
 * @return the fetch requests for all remote blocks
 * @throws BlockException if a remote block reports a negative size
 */
override protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Remember the count before zero-sized blocks are subtracted, for the summary log below.
  val totalBeforeSplit = _totalBlocks
  val requests = new ArrayBuffer[FetchRequest]
  for ((address, blockInfos) <- blocksByAddress) {
    if (address != blockManagerId) {
      remoteBlockIds ++= blockInfos.map(_._1)
      // Requests are closed at maxBytesInFlight / 5 rather than maxBytesInFlight so that
      // up to 5 nodes can be fetched from in parallel instead of blocking on a single one.
      val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
      var pendingBytes = 0L
      var pendingBlocks = new ArrayBuffer[(String, Long)]
      blockInfos.iterator.foreach { case (blockId, size) =>
        if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        } else if (size == 0) {
          // Empty block: nothing to fetch, so it no longer counts towards the total.
          _totalBlocks -= 1
        } else {
          pendingBlocks += ((blockId, size))
          pendingBytes += size
        }
        if (pendingBytes >= minRequestSize) {
          // The batch is large enough: emit it and start a fresh one.
          requests += new FetchRequest(address, pendingBlocks)
          pendingBytes = 0
          pendingBlocks = new ArrayBuffer[(String, Long)]
        }
      }
      // Whatever did not reach the threshold still has to be fetched.
      if (pendingBlocks.nonEmpty) {
        requests += new FetchRequest(address, pendingBlocks)
      }
    } else {
      localBlockIds ++= blockInfos.map(_._1)
    }
  }
  logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
    totalBeforeSplit + " blocks")
  requests
}
private var copiers: List[_ <: Thread] = null
override def initialize() {
......
......@@ -52,7 +52,7 @@ private[spark] object AkkaUtils {
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
lifecycleEvents, akkaWriteTimeout))
val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
val actorSystem = ActorSystem(name, akkaConf)
// Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
// hack because Akka doesn't let you figure out the port through the public API yet.
......
......@@ -341,6 +341,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(c.count === 10)
}
test("zero sized blocks") {
  // A local cluster with 2 processes ensures there are both local and remote blocks.
  sc = new SparkContext("local-cluster[2,1,512]", "test")

  // Only 4 keys are spread over 10 partitions, so some shuffle blocks can be empty.
  val numPartitions = 10
  val pairs = sc.parallelize(1 to 4, numPartitions).map(x => (x, x*2))

  // NOTE: The default Java serializer doesn't create zero-sized blocks, so Kryo is used.
  val shuffled = new ShuffledRDD(pairs, new HashPartitioner(numPartitions),
    classOf[spark.KryoSerializer].getName)
  val shuffleId = shuffled.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId

  assert(shuffled.count === 4)

  // Collect the size component of every status reported for this shuffle.
  val blockSizes = (0 until numPartitions).flatMap { idx =>
    SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, idx).map(_._2)
  }

  // We should have at most 4 non-zero sized partitions
  val nonEmptyCount = blockSizes.count(_ > 0)
  assert(nonEmptyCount <= 4)
}
}
object ShuffleSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment