diff --git a/conf/java-opts b/conf/java-opts index 63c44f76395ba4610416a44d6fdeada85dc2697d..702e3b01f5e931e696a81f56eb8db5766697e72a 100644 --- a/conf/java-opts +++ b/conf/java-opts @@ -1 +1 @@ --Dspark.shuffle.class=spark.TrackedCustomBlockedLocalFileShuffle -Dspark.shuffle.masterHostAddress=127.0.0.1 -Dspark.shuffle.masterTrackerPort=22222 -Dspark.shuffle.trackerStrategy=spark.BalanceConnectionsShuffleTrackerStrategy -Dspark.shuffle.maxRxConnections=2 -Dspark.shuffle.maxTxConnections=2 -Dspark.shuffle.blockSize=4096 -Dspark.shuffle.minKnockInterval=100 -Dspark.shuffle.maxKnockInterval=2000 -Dspark.shuffle.maxChatTime=500 +-Dspark.shuffle.class=spark.TrackedCustomBlockedLocalFileShuffle -Dspark.shuffle.masterHostAddress=127.0.0.1 -Dspark.shuffle.masterTrackerPort=22222 -Dspark.shuffle.trackerStrategy=spark.BalanceRemainingShuffleTrackerStrategy -Dspark.shuffle.maxRxConnections=2 -Dspark.shuffle.maxTxConnections=2 -Dspark.shuffle.blockSize=4096 -Dspark.shuffle.minKnockInterval=100 -Dspark.shuffle.maxKnockInterval=2000 -Dspark.shuffle.maxChatTime=10 -Dspark.shuffle.throttleFraction=2.0 diff --git a/src/scala/spark/Shuffle.scala b/src/scala/spark/Shuffle.scala index be78bf9a3c4aa858abf96056e683c135137d4640..0a547fb71509d01408b7d81b14a6d59cbfdf16fe 100644 --- a/src/scala/spark/Shuffle.scala +++ b/src/scala/spark/Shuffle.scala @@ -49,6 +49,11 @@ extends Logging { "spark.shuffle.maxChatTime", "250").toInt private var MaxChatBlocks_ = System.getProperty( "spark.shuffle.maxChatBlocks", "1024").toInt + + // A reducer is throttled if it is this much faster + private var ThrottleFraction_ = System.getProperty( + "spark.shuffle.throttleFraction", "2.0").toDouble + def MasterHostAddress = MasterHostAddress_ def MasterTrackerPort = MasterTrackerPort_ @@ -64,6 +69,8 @@ extends Logging { def MaxChatTime = MaxChatTime_ def MaxChatBlocks = MaxChatBlocks_ + def ThrottleFraction = ThrottleFraction_ + // Returns a standard ThreadFactory except all threads are daemons private def 
newDaemonThreadFactory: ThreadFactory = { new ThreadFactory { @@ -98,10 +105,15 @@ extends Logging { @serializable case class SplitInfo(val hostAddress: String, val listenPort: Int, - val inputId: Int) { + val splitId: Int) { var hasSplits = 0 var hasSplitsBitVector: BitSet = null + + // Used by mappers of dim |numOutputSplits| + var totalBlocksPerOutputSplit: Array[Int] = null + // Used by reducers of dim |numInputSplits| + var hasBlocksPerInputSplit: Array[Int] = null } object SplitInfo { diff --git a/src/scala/spark/ShuffleTrackerStrategy.scala b/src/scala/spark/ShuffleTrackerStrategy.scala index 021c69557f488115645ac0a68bf6eb93d56936d9..8c1d3b00d9ecb4400b897366ee764594aa23fd24 100644 --- a/src/scala/spark/ShuffleTrackerStrategy.scala +++ b/src/scala/spark/ShuffleTrackerStrategy.scala @@ -1,5 +1,7 @@ package spark +import scala.util.Sorting._ + /** * A trait for implementing tracker strategies for the shuffle system. */ @@ -12,19 +14,25 @@ trait ShuffleTrackerStrategy { // A reducer is done. Update internal stats def deleteReducerFrom(reducerSplitInfo: SplitInfo, - serverSplitIndex: Int): Unit + receptionStat: ReceptionStats): Unit } +/** + * Helper class to send back reception stats from the reducer + */ +case class ReceptionStats(val bytesReceived: Int, val timeSpent: Int, + serverSplitIndex: Int) { } + /** * A simple ShuffleTrackerStrategy that tries to balance the total number of * connections created for each mapper. 
*/ class BalanceConnectionsShuffleTrackerStrategy extends ShuffleTrackerStrategy with Logging { - var numSources = -1 - var outputLocs: Array[SplitInfo] = null - var curConnectionsPerLoc: Array[Int] = null - var totalConnectionsPerLoc: Array[Int] = null + private var numSources = -1 + private var outputLocs: Array[SplitInfo] = null + private var curConnectionsPerLoc: Array[Int] = null + private var totalConnectionsPerLoc: Array[Int] = null // The order of elements in the outputLocs (splitIndex) is used to pass // information back and forth between the tracker, mappers, and reducers @@ -57,27 +65,154 @@ extends ShuffleTrackerStrategy with Logging { curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1 totalConnectionsPerLoc(splitIndex) = totalConnectionsPerLoc(splitIndex) + 1 - - curConnectionsPerLoc.foreach { i => - print ("" + i + " ") - } - println("") } return splitIndex } def deleteReducerFrom(reducerSplitInfo: SplitInfo, - serverSplitIndex: Int): Unit = synchronized { + receptionStat: ReceptionStats): Unit = synchronized { // Decrease number of active connections - curConnectionsPerLoc(serverSplitIndex) = - curConnectionsPerLoc(serverSplitIndex) - 1 + curConnectionsPerLoc(receptionStat.serverSplitIndex) = + curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1 + + assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) + } +} + +/** + * Shuffle tracker strategy that tries to balance the percentage of blocks + * remaining for each reducer + */ +class BalanceRemainingShuffleTrackerStrategy +extends ShuffleTrackerStrategy with Logging { + // Number of mappers + private var numMappers = -1 + // Number of reducers + private var numReducers = -1 + private var outputLocs: Array[SplitInfo] = null + + // Data structures from reducers' perspectives + private var totalBlocksPerInputSplit: Array[Array[Int]] = null + private var hasBlocksPerInputSplit: Array[Array[Int]] = null + + // Stored in bytes per millisecond + private var 
speedPerInputSplit: Array[Array[Int]] = null + + private var curConnectionsPerLoc: Array[Int] = null + private var totalConnectionsPerLoc: Array[Int] = null + + // The order of elements in the outputLocs (splitIndex) is used to pass + // information back and forth between the tracker, mappers, and reducers + def initialize(outputLocs_ : Array[SplitInfo]): Unit = { + outputLocs = outputLocs_ + + numMappers = outputLocs.size + + // All the outputLocs have totalBlocksPerOutputSplit of same size + numReducers = outputLocs(0).totalBlocksPerOutputSplit.size + + // Now initialize the data structures + totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) => + outputLocs(j).totalBlocksPerOutputSplit(i)) + hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0) + + // Initialize to -1 + speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1) + + curConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0) + totalConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0) + } + + def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized { + var splitIndex = -1 + + // Estimate time remaining to finish receiving for all reducer/mapper pairs + var individualEstimates = Array.tabulate(numReducers, numMappers)((i,j) => + (totalBlocksPerInputSplit(i)(j) - hasBlocksPerInputSplit(i)(j)) * + Shuffle.BlockSize / + speedPerInputSplit(i)(j)) + + println("reducerSplitInfo = " + reducerSplitInfo.splitId) + + for (i <- 0 until numReducers) { + for (j <- 0 until numMappers) { + print(individualEstimates(i)(j) + " ") + } + println("") + } - assert(curConnectionsPerLoc(serverSplitIndex) >= 0) + // Estimate time remaining to finish receiving for each reducer + var completionEstimates = Array.tabulate(numReducers)( + individualEstimates(_).foldLeft(Int.MinValue)(Math.max(_,_))) - curConnectionsPerLoc.foreach { i => - print ("" + i + " ") + for (i <- 0 until numReducers) { + print(completionEstimates(i) + " ") } println("") + 
+ // Check if all individualEstimates entries have non-negative values + var estimationComplete = true + for (i <- 0 until numReducers; j <- 0 until numMappers) { + if (individualEstimates(i)(j) < 0) { + estimationComplete = false + } + } + + // Take this reducer's estimate out + val myCompletionEstimate = completionEstimates(reducerSplitInfo.splitId) + + // Sort everyone's time + quickSort(completionEstimates) + + // If this reducer is going to complete F times faster than the 2nd + // fastest one just block this one for a while + // TODO: Must be able to support group division instead of singling one out + // TODO: Must have an endGame fraction + if (estimationComplete && numReducers > 1 && + Shuffle.ThrottleFraction * myCompletionEstimate < + completionEstimates(1)) { + splitIndex = -1 + println("Throttling reducer-" + reducerSplitInfo.splitId) + } else { + var minConnections = Int.MaxValue + for (i <- 0 until numMappers) { + // TODO: Use of MaxRxConnections instead of MaxTxConnections is + // intentional here. MaxTxConnections is per machine whereas + // MaxRxConnections is per mapper/reducer. Will have to find a better way. + if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections && + totalConnectionsPerLoc(i) < minConnections && + !reducerSplitInfo.hasSplitsBitVector.get(i)) { + minConnections = totalConnectionsPerLoc(i) + splitIndex = i + } + } + } + + if (splitIndex != -1) { + curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1 + totalConnectionsPerLoc(splitIndex) = + totalConnectionsPerLoc(splitIndex) + 1 + } + + return splitIndex + } + + def deleteReducerFrom(reducerSplitInfo: SplitInfo, + receptionStat: ReceptionStats): Unit = synchronized { + // Update hasBlocksPerInputSplit for reducerSplitInfo + hasBlocksPerInputSplit(reducerSplitInfo.splitId) = + reducerSplitInfo.hasBlocksPerInputSplit + + // Store the last known speed. Add 1 to avoid divide-by-zero. + // TODO: We are forgetting the old speed. Can use averaging at some point. 
+ speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) = + receptionStat.bytesReceived / (receptionStat.timeSpent + 1) + + // Update current connections to the mapper + curConnectionsPerLoc(receptionStat.serverSplitIndex) = + curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1 + + assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0) } } diff --git a/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala b/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala index 36049509ddbb91e9e158f3f535abb8d791a7834b..ae15b61a475a96a24551529645c7c28d3b700bde 100644 --- a/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala +++ b/src/scala/spark/TrackedCustomBlockedLocalFileShuffle.scala @@ -72,6 +72,9 @@ extends Shuffle[K, V, C] with Logging { } } + // Keep track of number of blocks for each output split + var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0) + for (i <- 0 until numOutputSplits) { var blockNum = 0 var isDirty = false @@ -122,10 +125,16 @@ extends Shuffle[K, V, C] with Logging { out = new ObjectOutputStream(new FileOutputStream(file)) out.writeObject(blockNum) out.close() + + // Store number of blocks for this outputSplit + numBlocksPerOutputSplit(i) = blockNum } - (SplitInfo (TrackedCustomBlockedLocalFileShuffle.serverAddress, - TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex)) + var retVal = SplitInfo(TrackedCustomBlockedLocalFileShuffle.serverAddress, + TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex) + retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit + + (retVal) }).collect() // Start tracker @@ -159,15 +168,15 @@ extends Shuffle[K, V, C] with Logging { while (hasSplits < totalSplits && numThreadsToCreate > 0) { // Receive which split to pull from the tracker - val splitIndex = getTrackerSelectedSplit(outputLocs) + val splitIndex = getTrackerSelectedSplit(myId) if (splitIndex != -1) { val selectedSplitInfo = outputLocs(splitIndex) val requestSplit = - 
"%d/%d/%d".format(shuffleId, selectedSplitInfo.inputId, myId) + "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, - requestSplit)) + requestSplit, myId)) // splitIndex is in transit. Will be unset in the ShuffleClient splitsInRequestBitVector.synchronized { @@ -202,17 +211,25 @@ extends Shuffle[K, V, C] with Logging { }) } - private def getLocalSplitInfo: SplitInfo = { + private def getLocalSplitInfo(myId: Int): SplitInfo = { var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, - SplitInfo.UnusedParam, SplitInfo.UnusedParam) - + SplitInfo.UnusedParam, myId) + + // Store hasSplits localSplitInfo.hasSplits = hasSplits + // Store hasSplitsBitVector hasSplitsBitVector.synchronized { localSplitInfo.hasSplitsBitVector = hasSplitsBitVector.clone.asInstanceOf[BitSet] } + // Store hasBlocksInSplit to hasBlocksPerInputSplit + hasBlocksInSplit.synchronized { + localSplitInfo.hasBlocksPerInputSplit = + hasBlocksInSplit.clone.asInstanceOf[Array[Int]] + } + // Include the splitsInRequest as well splitsInRequestBitVector.synchronized { localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector) @@ -241,9 +258,9 @@ extends Shuffle[K, V, C] with Logging { } // Talks to the tracker and receives instruction - private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = { + private def getTrackerSelectedSplit(myId: Int): Int = { // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo + val localSplitInfo = getLocalSplitInfo(myId) // DO NOT talk to the tracker if all the required splits are already busy if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { @@ -346,17 +363,20 @@ extends Shuffle[K, V, C] with Logging { } else if (reducerIntention == TrackedCustomBlockedLocalFileShuffle.ReducerLeaving) { - // Receive reducerSplitInfo and serverSplitIndex val reducerSplitInfo = - 
ois.readObject.asInstanceOf[SplitInfo] - val serverSplitIndex = ois.readObject.asInstanceOf[Int] + ois.readObject.asInstanceOf[SplitInfo] + + // Receive reception stats: how many blocks the reducer + // read in how much time and from where + val receptionStat = + ois.readObject.asInstanceOf[ReceptionStats] // Update stats trackerStrategy.deleteReducerFrom(reducerSplitInfo, - serverSplitIndex) + receptionStat) // Send ACK - oos.writeObject(serverSplitIndex) + oos.writeObject(receptionStat.serverSplitIndex) oos.flush() } else { @@ -427,7 +447,7 @@ extends Shuffle[K, V, C] with Logging { } class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, - requestSplit: String) + requestSplit: String, myId: Int) extends Thread with Logging { private var peerSocketToSource: Socket = null private var oosSource: ObjectOutputStream = null @@ -438,6 +458,10 @@ extends Shuffle[K, V, C] with Logging { // Make sure that multiple messages don't go to the tracker private var alreadySentLeavingNotification = false + // Keep track of bytes received and time spent + private var numBytesReceived = 0 + private var totalTimeSpent = 0 + override def run: Unit = { // Setup the timeout mechanism var timeOutTask = new TimerTask { @@ -534,6 +558,10 @@ extends Shuffle[K, V, C] with Logging { logInfo("END READ: " + requestPath) val readTime = System.currentTimeMillis - readStartTime logInfo("Reading " + requestPath + " took " + readTime + " millis.") + + // Update stats + numBytesReceived = numBytesReceived + requestedFileLen + totalTimeSpent = totalTimeSpent + readTime.toInt } else { throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) } @@ -574,11 +602,12 @@ extends Shuffle[K, V, C] with Logging { oosTracker.flush() // Send reducerSplitInfo - oosTracker.writeObject(getLocalSplitInfo) + oosTracker.writeObject(getLocalSplitInfo(myId)) oosTracker.flush() - // Send serverSplitInfo so that tracker can update its stats - 
oosTracker.writeObject(splitIndex) + // Send reception stats + oosTracker.writeObject(ReceptionStats( + numBytesReceived, totalTimeSpent, splitIndex)) oosTracker.flush() // Receive ACK. No need to do anything with that diff --git a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala index 1a58155a85847419e6fba243183e61a5fdddc86b..04fbcc59fb7ef1c59781cc5b4a3acded6b11f0c5 100644 --- a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala +++ b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala @@ -96,8 +96,8 @@ extends Shuffle[K, V, C] with Logging { receivedData = new LinkedBlockingQueue[(Int, Array[Byte])] combiners = new HashMap[K, C] - var threadPool = - Shuffle.newDaemonFixedThreadPool(Shuffle.MaxRxConnections) + var threadPool = Shuffle.newDaemonFixedThreadPool( + Shuffle.MaxRxConnections) while (hasSplits < totalSplits) { var numThreadsToCreate = @@ -106,15 +106,15 @@ extends Shuffle[K, V, C] with Logging { while (hasSplits < totalSplits && numThreadsToCreate > 0) { // Receive which split to pull from the tracker - val splitIndex = getTrackerSelectedSplit(outputLocs) + val splitIndex = getTrackerSelectedSplit(myId) if (splitIndex != -1) { val selectedSplitInfo = outputLocs(splitIndex) val requestSplit = - "%d/%d/%d".format(shuffleId, selectedSplitInfo.inputId, myId) + "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId) threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo, - requestSplit)) + requestSplit, myId)) // splitIndex is in transit. 
Will be unset in the ShuffleClient splitsInRequestBitVector.synchronized { @@ -149,9 +149,9 @@ extends Shuffle[K, V, C] with Logging { }) } - private def getLocalSplitInfo: SplitInfo = { + private def getLocalSplitInfo(myId: Int): SplitInfo = { var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress, - SplitInfo.UnusedParam, SplitInfo.UnusedParam) + SplitInfo.UnusedParam, myId) localSplitInfo.hasSplits = hasSplits @@ -189,9 +189,9 @@ extends Shuffle[K, V, C] with Logging { } // Talks to the tracker and receives instruction - private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = { + private def getTrackerSelectedSplit(myId: Int): Int = { // Local status of hasSplitsBitVector and splitsInRequestBitVector - val localSplitInfo = getLocalSplitInfo + val localSplitInfo = getLocalSplitInfo(myId) // DO NOT talk to the tracker if all the required splits are already busy if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) { @@ -294,17 +294,20 @@ extends Shuffle[K, V, C] with Logging { } else if (reducerIntention == TrackedCustomParallelLocalFileShuffle.ReducerLeaving) { - // Receive reducerSplitInfo and serverSplitIndex val reducerSplitInfo = - ois.readObject.asInstanceOf[SplitInfo] - val serverSplitIndex = ois.readObject.asInstanceOf[Int] + ois.readObject.asInstanceOf[SplitInfo] + + // Receive reception stats: how many blocks the reducer + // read in how much time and from where + val receptionStat = + ois.readObject.asInstanceOf[ReceptionStats] // Update stats trackerStrategy.deleteReducerFrom(reducerSplitInfo, - serverSplitIndex) + receptionStat) // Send ACK - oos.writeObject(serverSplitIndex) + oos.writeObject(receptionStat.serverSplitIndex) oos.flush() } else { @@ -375,7 +378,7 @@ extends Shuffle[K, V, C] with Logging { } class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo, - requestSplit: String) + requestSplit: String, myId: Int) extends Thread with Logging { private var peerSocketToSource: Socket = null 
private var oosSource: ObjectOutputStream = null @@ -386,6 +389,10 @@ extends Shuffle[K, V, C] with Logging { // Make sure that multiple messages don't go to the tracker private var alreadySentLeavingNotification = false + // Keep track of bytes received and time spent + private var numBytesReceived = 0 + private var totalTimeSpent = 0 + override def run: Unit = { // Setup the timeout mechanism var timeOutTask = new TimerTask { @@ -467,6 +474,10 @@ extends Shuffle[K, V, C] with Logging { logInfo("END READ: " + requestPath) val readTime = System.currentTimeMillis - readStartTime logInfo("Reading " + requestPath + " took " + readTime + " millis.") + + // Update stats + numBytesReceived = numBytesReceived + requestedFileLen + totalTimeSpent = totalTimeSpent + readTime.toInt } else { throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit) } @@ -506,11 +517,12 @@ extends Shuffle[K, V, C] with Logging { oosTracker.flush() // Send reducerSplitInfo - oosTracker.writeObject(getLocalSplitInfo) + oosTracker.writeObject(getLocalSplitInfo(myId)) oosTracker.flush() - // Send serverSplitInfo so that tracker can update its stats - oosTracker.writeObject(splitIndex) + // Send reception stats + oosTracker.writeObject(ReceptionStats( + numBytesReceived, totalTimeSpent, splitIndex)) oosTracker.flush() // Receive ACK. No need to do anything with that