diff --git a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala index 6914c40e684fa69d1e3406b150316480c524cf91..5b14f1e2d1916a5de4af6d64e06a3511121c4806 100644 --- a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala +++ b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala @@ -507,8 +507,8 @@ extends Shuffle[K, V, C] with Logging { logInfo("sendLeavingNotification had a " + e) } } finally { - oosTracker.close() oisTracker.close() + oosTracker.close() clientSocketToTracker.close() } } @@ -541,36 +541,40 @@ trait ShuffleTrackerStrategy { class BalanceConnectionsShuffleTrackerStrategy extends ShuffleTrackerStrategy with Logging { var outputLocs: Array[SplitInfo] = null - var numConnectionsPerLoc: Array[Int] = null + var curConnectionsPerLoc: Array[Int] = null + var totalConnectionsPerLoc: Array[Int] = null def initialize(outputLocs_ : Array[SplitInfo]): Unit = { outputLocs = outputLocs_ // Now initialize other data structures - numConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0) + curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0) + totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0) } def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized { var minConnections = Int.MaxValue var splitIndex = -1 - for (i <- 0 until numConnectionsPerLoc.size) { + for (i <- 0 until curConnectionsPerLoc.size) { // TODO: Use of MaxRxConnections instead of MaxTxConnections is // intentional here. MaxTxConnections is per machine whereas // MaxRxConnections is per mapper/reducer. Will have to find a better way. - if (numConnectionsPerLoc(i) < TrackedCustomParallelLocalFileShuffle.MaxRxConnections && - numConnectionsPerLoc(i) < minConnections && + if (curConnectionsPerLoc(i) < TrackedCustomParallelLocalFileShuffle.MaxRxConnections && + totalConnectionsPerLoc(i) < minConnections && !reducerSplitInfo.hasSplitsBitVector.get(i)) { - minConnections = numConnectionsPerLoc(i) + minConnections = totalConnectionsPerLoc(i) splitIndex = i } } if (splitIndex != -1) { - numConnectionsPerLoc(splitIndex) = - numConnectionsPerLoc(splitIndex) + 1 + curConnectionsPerLoc(splitIndex) = + curConnectionsPerLoc(splitIndex) + 1 + totalConnectionsPerLoc(splitIndex) = + totalConnectionsPerLoc(splitIndex) + 1 - numConnectionsPerLoc.foreach { i => + totalConnectionsPerLoc.foreach { i => print ("" + i + " ") } println("") @@ -581,16 +585,20 @@ extends ShuffleTrackerStrategy with Logging { def deleteReducerFrom(reducerSplitInfo: SplitInfo, serverSplitIndex: Int): Unit = synchronized { - assert(numConnectionsPerLoc(serverSplitIndex) > 0) // Decrease number of active connections - numConnectionsPerLoc(serverSplitIndex) = - numConnectionsPerLoc(serverSplitIndex) - 1 + curConnectionsPerLoc(serverSplitIndex) = + curConnectionsPerLoc(serverSplitIndex) - 1 - numConnectionsPerLoc.foreach { i => + totalConnectionsPerLoc.foreach { i => print ("" + i + " ") } println("") + + // TODO: Remove this once the bug is fixed + if (curConnectionsPerLoc(serverSplitIndex) < 0) { + curConnectionsPerLoc(serverSplitIndex) = 0 + } } }