diff --git a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
index 7a408ac1500ec11e1745119e716ea7b8deb192a2..44822f5143cc24d4847c07dd5c240d10e484ade1 100644
--- a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
+++ b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
@@ -181,9 +181,9 @@ extends Shuffle[K, V, C] with Logging {
 
   // Talks to the tracker and receives instruction
   private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = {
-    val clientSocketToTracker =
-      new Socket(TrackedCustomParallelLocalFileShuffle.MasterHostAddress,
-      TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+    val clientSocketToTracker = new Socket(
+      TrackedCustomParallelLocalFileShuffle.MasterHostAddress,
+      TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
     val oosTracker =
       new ObjectOutputStream(clientSocketToTracker.getOutputStream)
     oosTracker.flush()
@@ -282,8 +282,7 @@ extends Shuffle[K, V, C] with Logging {
             TrackedCustomParallelLocalFileShuffle.ReducerLeaving) {
           // Receive reducerSplitInfo and serverSplitIndex
           val reducerSplitInfo =
-            ois.readObject.asInstanceOf[SplitInfo]
-
+            ois.readObject.asInstanceOf[SplitInfo]
           val serverSplitIndex = ois.readObject.asInstanceOf[Int]
 
           // Update stats
@@ -433,7 +432,7 @@ extends Shuffle[K, V, C] with Logging {
       var alreadyRead = 0
       var bytesRead = 0
 
-      while (alreadyRead != requestedFileLen) {
+      while (alreadyRead < requestedFileLen) {
         bytesRead = isSource.read(recvByteArray, alreadyRead,
           requestedFileLen - alreadyRead)
         if (bytesRead > 0) {
@@ -568,7 +567,8 @@ extends ShuffleTrackerStrategy with Logging {
       // TODO: Use of MaxRxConnections instead of MaxTxConnections is
      // intentional here. MaxTxConnections is per machine whereas
      // MaxRxConnections is per mapper/reducer. Will have to find a better way.
-      if (curConnectionsPerLoc(i) < TrackedCustomParallelLocalFileShuffle.MaxRxConnections &&
+      if (curConnectionsPerLoc(i) <
+        TrackedCustomParallelLocalFileShuffle.MaxRxConnections &&
         totalConnectionsPerLoc(i) < minConnections &&
         !reducerSplitInfo.hasSplitsBitVector.get(i)) {
         minConnections = totalConnectionsPerLoc(i)
@@ -577,8 +577,7 @@ extends ShuffleTrackerStrategy with Logging {
     }
 
     if (splitIndex != -1) {
-      curConnectionsPerLoc(splitIndex) =
-        curConnectionsPerLoc(splitIndex) + 1
+      curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
       totalConnectionsPerLoc(splitIndex) =
         totalConnectionsPerLoc(splitIndex) + 1
 
@@ -593,25 +592,21 @@ extends ShuffleTrackerStrategy with Logging {
 
   def deleteReducerFrom(reducerSplitInfo: SplitInfo,
     serverSplitIndex: Int): Unit = synchronized {
-    // Decrease number of active connections
     curConnectionsPerLoc(serverSplitIndex) =
       curConnectionsPerLoc(serverSplitIndex) - 1
+    assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
+
     curConnectionsPerLoc.foreach { i => print ("" + i + " ") }
     println("")
-
-    // TODO: Remove this once the bug is fixed
-    if (curConnectionsPerLoc(serverSplitIndex) < 0) {
-      curConnectionsPerLoc(serverSplitIndex) = 0
-    }
   }
 }
 
 @serializable
-case class SplitInfo (val hostAddress: String, val listenPort: Int,
+case class SplitInfo(val hostAddress: String, val listenPort: Int,
   val inputId: Int) {
 
   var hasSplits = 0
@@ -761,8 +756,8 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
 
   class ShuffleServer extends Thread with Logging {
 
-    var threadPool =
-      newDaemonFixedThreadPool(CustomParallelLocalFileShuffle.MaxTxConnections)
+    var threadPool = newDaemonFixedThreadPool(
+      CustomParallelLocalFileShuffle.MaxTxConnections)
 
     var serverSocket: ServerSocket = null
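
A note on the `@@ -433` hunk: `InputStream.read` may return fewer bytes than requested, so the receive loop has to accumulate partial reads. Switching the exit test from `!=` to `<` guarantees termination even if `alreadyRead` ever overshoots `requestedFileLen`. The loop can still spin if the stream hits end-of-file early, since `read` then returns -1 and `alreadyRead` stops advancing. Below is a minimal standalone sketch of a read-fully loop that also covers that case; the names (`ReadFully`, `readFully`, `is`, `buf`, `len`) are illustrative, not taken from the patch.

```scala
import java.io.{EOFException, InputStream}

object ReadFully {
  // Read exactly `len` bytes into `buf`, accumulating partial reads.
  def readFully(is: InputStream, buf: Array[Byte], len: Int): Unit = {
    var alreadyRead = 0
    while (alreadyRead < len) {
      val bytesRead = is.read(buf, alreadyRead, len - alreadyRead)
      // read returns -1 at end of stream; bail out instead of looping forever.
      if (bytesRead == -1)
        throw new EOFException(
          "stream ended after " + alreadyRead + " of " + len + " bytes")
      alreadyRead += bytesRead
    }
  }
}
```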
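A note on the `@@ -593` hunk: the tracker strategy keeps a per-location count of active connections that `addReducerTo` increments and `deleteReducerFrom` decrements. The patch drops the old workaround, which clamped a negative count back to zero under a `// TODO: Remove this once the bug is fixed`, in favor of an `assert`, so an unbalanced decrement now fails loudly instead of being masked. A stripped-down, hypothetical sketch of that accounting (the class and parameter names are illustrative, not from the patch):

```scala
// Hypothetical skeleton of the tracker's per-location connection counting.
class ConnectionAccounting(numLocations: Int) {
  private val curConnectionsPerLoc = Array.fill(numLocations)(0)

  // The tracker assigns a reducer to a map output location.
  def addReducerTo(loc: Int): Unit = synchronized {
    curConnectionsPerLoc(loc) += 1
  }

  // A reducer reports that it has finished with a location.
  def deleteReducerFrom(loc: Int): Unit = synchronized {
    curConnectionsPerLoc(loc) -= 1
    // Fail fast on an unbalanced decrement rather than clamping to zero,
    // so the underlying bookkeeping bug surfaces where it happens.
    assert(curConnectionsPerLoc(loc) >= 0)
  }
}
```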