Commit 07e778d7 authored by Mosharaf Chowdhury

- Updated Reducer-Tracker communication protocol.

- Implemented a new tracker strategy for shuffle where, if a reducer is too fast, it is stalled until the others catch up. The basic version is working, but more work is necessary.
parent 33d59fb2
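
The throttling behavior described in the commit message is implemented in BalanceRemainingShuffleTrackerStrategy in the diff below. As a minimal sketch (not the committed code), assuming completion-time estimates are sorted in ascending order and include the requesting reducer's own estimate:

// Sketch only; names mirror the diff below.
def shouldThrottle(myEstimateMs: Double, sortedEstimatesMs: Array[Double],
    throttleFraction: Double): Boolean = {
  // Hold the reducer back while it is projected to finish more than
  // throttleFraction times sooner than the second-fastest reducer.
  sortedEstimatesMs.length > 1 &&
    throttleFraction * myEstimateMs < sortedEstimatesMs(1)
}

In the diff, a throttled reducer simply gets split index -1 back from the tracker and asks again later.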
- -Dspark.shuffle.class=spark.TrackedCustomBlockedLocalFileShuffle -Dspark.shuffle.masterHostAddress=127.0.0.1 -Dspark.shuffle.masterTrackerPort=22222 -Dspark.shuffle.trackerStrategy=spark.BalanceConnectionsShuffleTrackerStrategy -Dspark.shuffle.maxRxConnections=2 -Dspark.shuffle.maxTxConnections=2 -Dspark.shuffle.blockSize=4096 -Dspark.shuffle.minKnockInterval=100 -Dspark.shuffle.maxKnockInterval=2000 -Dspark.shuffle.maxChatTime=500
+ -Dspark.shuffle.class=spark.TrackedCustomBlockedLocalFileShuffle -Dspark.shuffle.masterHostAddress=127.0.0.1 -Dspark.shuffle.masterTrackerPort=22222 -Dspark.shuffle.trackerStrategy=spark.BalanceRemainingShuffleTrackerStrategy -Dspark.shuffle.maxRxConnections=2 -Dspark.shuffle.maxTxConnections=2 -Dspark.shuffle.blockSize=4096 -Dspark.shuffle.minKnockInterval=100 -Dspark.shuffle.maxKnockInterval=2000 -Dspark.shuffle.maxChatTime=10 -Dspark.shuffle.throttleFraction=2.0
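
The flag changes above switch spark.shuffle.trackerStrategy from BalanceConnectionsShuffleTrackerStrategy to the new BalanceRemainingShuffleTrackerStrategy and add spark.shuffle.throttleFraction. A rough worked reading of throttleFraction = 2.0 (the millisecond values are invented for illustration):

// Illustration only; values are made up.
val throttleFraction = 2.0
val myEstimateMs = 3000.0      // this reducer expects to finish in 3 s
val secondFastestMs = 7000.0   // the next-fastest reducer needs 7 s
// 2.0 * 3000 = 6000 < 7000, so the tracker tells this reducer to wait
val throttled = throttleFraction * myEstimateMs < secondFastestMs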
@@ -49,6 +49,11 @@ extends Logging {
"spark.shuffle.maxChatTime", "250").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.shuffle.maxChatBlocks", "1024").toInt
// A reducer is throttled if it is this much faster
private var ThrottleFraction_ = System.getProperty(
"spark.shuffle.throttleFraction", "2.0").toDouble
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
@@ -64,6 +69,8 @@ extends Logging {
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def ThrottleFraction = ThrottleFraction_
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
@@ -98,10 +105,15 @@ extends Logging {
@serializable
case class SplitInfo(val hostAddress: String, val listenPort: Int,
- val inputId: Int) {
+ val splitId: Int) {
var hasSplits = 0
var hasSplitsBitVector: BitSet = null
// Used by mappers of dim |numOutputSplits|
var totalBlocksPerOutputSplit: Array[Int] = null
// Used by reducers of dim |numInputSplits|
var hasBlocksPerInputSplit: Array[Int] = null
}
object SplitInfo {
......
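
Before the next file, a small hypothetical illustration of the two arrays SplitInfo now carries (all values are invented): a mapper reports how many blocks it wrote per output split, while a reducer reports how many blocks it has already fetched from each mapper.

// Hypothetical values, for illustration only.
val mapperInfo = SplitInfo("10.0.0.5", 31337, 2)       // mapper with splitId 2
mapperInfo.totalBlocksPerOutputSplit = Array(12, 7, 9) // blocks written for 3 output splits

val reducerInfo = SplitInfo("10.0.0.9", SplitInfo.UnusedParam, 1)
reducerInfo.hasBlocksPerInputSplit = Array(4, 0, 9, 2) // blocks pulled so far from 4 mappers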
package spark
import scala.util.Sorting._
/**
* A trait for implementing tracker strategies for the shuffle system.
*/
@@ -12,19 +14,25 @@ trait ShuffleTrackerStrategy {
// A reducer is done. Update internal stats
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- serverSplitIndex: Int): Unit
+ receptionStat: ReceptionStats): Unit
}
/**
* Helper class to send back reception stats from the reducer
*/
case class ReceptionStats(val bytesReceived: Int, val timeSpent: Int,
serverSplitIndex: Int) { }
/**
* A simple ShuffleTrackerStrategy that tries to balance the total number of
* connections created for each mapper.
*/
class BalanceConnectionsShuffleTrackerStrategy
extends ShuffleTrackerStrategy with Logging {
- var numSources = -1
- var outputLocs: Array[SplitInfo] = null
- var curConnectionsPerLoc: Array[Int] = null
- var totalConnectionsPerLoc: Array[Int] = null
+ private var numSources = -1
+ private var outputLocs: Array[SplitInfo] = null
+ private var curConnectionsPerLoc: Array[Int] = null
+ private var totalConnectionsPerLoc: Array[Int] = null
// The order of elements in the outputLocs (splitIndex) is used to pass
// information back and forth between the tracker, mappers, and reducers
@@ -57,27 +65,154 @@ extends ShuffleTrackerStrategy with Logging {
curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
totalConnectionsPerLoc(splitIndex) =
totalConnectionsPerLoc(splitIndex) + 1
- curConnectionsPerLoc.foreach { i =>
- print ("" + i + " ")
- }
- println("")
}
return splitIndex
}
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
- serverSplitIndex: Int): Unit = synchronized {
+ receptionStat: ReceptionStats): Unit = synchronized {
// Decrease number of active connections
- curConnectionsPerLoc(serverSplitIndex) =
- curConnectionsPerLoc(serverSplitIndex) - 1
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) =
+ curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1
assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
}
}
/**
* Shuffle tracker strategy that tries to balance the percentage of blocks
* remaining for each reducer
*/
class BalanceRemainingShuffleTrackerStrategy
extends ShuffleTrackerStrategy with Logging {
// Number of mappers
private var numMappers = -1
// Number of reducers
private var numReducers = -1
private var outputLocs: Array[SplitInfo] = null
// Data structures from reducers' perspectives
private var totalBlocksPerInputSplit: Array[Array[Int]] = null
private var hasBlocksPerInputSplit: Array[Array[Int]] = null
// Stored in bytes per millisecond
private var speedPerInputSplit: Array[Array[Int]] = null
private var curConnectionsPerLoc: Array[Int] = null
private var totalConnectionsPerLoc: Array[Int] = null
// The order of elements in the outputLocs (splitIndex) is used to pass
// information back and forth between the tracker, mappers, and reducers
def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
outputLocs = outputLocs_
numMappers = outputLocs.size
// All the outputLocs have totalBlocksPerOutputSplit of same size
numReducers = outputLocs(0).totalBlocksPerOutputSplit.size
// Now initialize the data structures
totalBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((i,j) =>
outputLocs(j).totalBlocksPerOutputSplit(i))
hasBlocksPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => 0)
// Initialize to -1
speedPerInputSplit = Array.tabulate(numReducers, numMappers)((_,_) => -1)
curConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0)
totalConnectionsPerLoc = Array.tabulate(numMappers)(_ => 0)
}
def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
var splitIndex = -1
// Estimate time remaining to finish receiving for all reducer/mapper pairs
var individualEstimates = Array.tabulate(numReducers, numMappers)((i,j) =>
(totalBlocksPerInputSplit(i)(j) - hasBlocksPerInputSplit(i)(j)) *
Shuffle.BlockSize /
speedPerInputSplit(i)(j))
println("reducerSplitInfo = " + reducerSplitInfo.splitId)
for (i <- 0 until numReducers) {
for (j <- 0 until numMappers) {
print(individualEstimates(i)(j) + " ")
}
println("")
}
- assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
+ // Estimate time remaining to finish receiving for each reducer
var completionEstimates = Array.tabulate(numReducers)(
individualEstimates(_).foldLeft(Int.MinValue)(Math.max(_,_)))
- curConnectionsPerLoc.foreach { i =>
- print ("" + i + " ")
+ for (i <- 0 until numReducers) {
+ print(completionEstimates(i) + " ")
}
println("")
// Check if all individualEstimates entries have non-zero values
var estimationComplete = true
for (i <- 0 until numReducers; j <- 0 until numMappers) {
if (individualEstimates(i)(j) < 0) {
estimationComplete = false
}
}
// Take this reducers estimate out
val myCompletionEstimate = completionEstimates(reducerSplitInfo.splitId)
// Sort everyone's time
quickSort(completionEstimates)
// If this reducer is going to complete F times faster than the 2nd
// fastest one just block this one for a while
// TODO: Must be able to support group division instead of singling one out
// TODO: Must have a endGame fraction
if (estimationComplete && numReducers > 1 &&
Shuffle.ThrottleFraction * myCompletionEstimate <
completionEstimates(1)) {
splitIndex = -1
println("Throttling reducer-" + reducerSplitInfo.splitId)
} else {
var minConnections = Int.MaxValue
for (i <- 0 until numMappers) {
// TODO: Use of MaxRxConnections instead of MaxTxConnections is
// intentional here. MaxTxConnections is per machine whereas
// MaxRxConnections is per mapper/reducer. Will have to find a better way.
if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
totalConnectionsPerLoc(i) < minConnections &&
!reducerSplitInfo.hasSplitsBitVector.get(i)) {
minConnections = totalConnectionsPerLoc(i)
splitIndex = i
}
}
}
if (splitIndex != -1) {
curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
totalConnectionsPerLoc(splitIndex) =
totalConnectionsPerLoc(splitIndex) + 1
}
return splitIndex
}
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
receptionStat: ReceptionStats): Unit = synchronized {
// Update hasBlocksPerInputSplit for reducerSplitInfo
hasBlocksPerInputSplit(reducerSplitInfo.splitId) =
reducerSplitInfo.hasBlocksPerInputSplit
// Store the last known speed. Add 1 to avoid divide-by-zero.
// TODO: We are forgetting the old speed. Can use averaging at some point.
speedPerInputSplit(reducerSplitInfo.splitId)(receptionStat.serverSplitIndex) =
receptionStat.bytesReceived / (receptionStat.timeSpent + 1)
// Update current connections to the mapper
curConnectionsPerLoc(receptionStat.serverSplitIndex) =
curConnectionsPerLoc(receptionStat.serverSplitIndex) - 1
assert(curConnectionsPerLoc(receptionStat.serverSplitIndex) >= 0)
}
}
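
To summarize the arithmetic in the strategy above (a condensed restatement, not the committed code): when a reducer reports back, the tracker records a per-pair speed in bytes per millisecond, projects the remaining time from the blocks still missing, and takes a reducer's completion estimate as its slowest per-mapper projection.

// Condensed restatement of the estimates above; not the committed code.
// Speed is bytes per millisecond; the +1 avoids divide-by-zero when timeSpent is 0.
def observedSpeed(bytesReceived: Int, timeSpentMs: Int): Int =
  bytesReceived / (timeSpentMs + 1)

// Remaining time for one reducer/mapper pair, in milliseconds.
// speedBytesPerMs must be positive; the tracker treats pairs that have not
// reported yet (speed initialized to -1) as "estimation incomplete".
def pairEstimateMs(totalBlocks: Int, hasBlocks: Int, blockSize: Int,
    speedBytesPerMs: Int): Int = {
  require(speedBytesPerMs > 0)
  (totalBlocks - hasBlocks) * blockSize / speedBytesPerMs
}

// A reducer finishes when its slowest source finishes.
def completionEstimateMs(pairEstimatesMs: Array[Int]): Int = pairEstimatesMs.max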
@@ -72,6 +72,9 @@ extends Shuffle[K, V, C] with Logging {
}
}
// Keep track of number of blocks for each output split
var numBlocksPerOutputSplit = Array.tabulate(numOutputSplits)(_ => 0)
for (i <- 0 until numOutputSplits) {
var blockNum = 0
var isDirty = false
@@ -122,10 +125,16 @@ extends Shuffle[K, V, C] with Logging {
out = new ObjectOutputStream(new FileOutputStream(file))
out.writeObject(blockNum)
out.close()
// Store number of blocks for this outputSplit
numBlocksPerOutputSplit(i) = blockNum
}
- (SplitInfo (TrackedCustomBlockedLocalFileShuffle.serverAddress,
- TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex))
+ var retVal = SplitInfo(TrackedCustomBlockedLocalFileShuffle.serverAddress,
+ TrackedCustomBlockedLocalFileShuffle.serverPort, myIndex)
+ retVal.totalBlocksPerOutputSplit = numBlocksPerOutputSplit
+ (retVal)
}).collect()
// Start tracker
@@ -159,15 +168,15 @@ extends Shuffle[K, V, C] with Logging {
while (hasSplits < totalSplits && numThreadsToCreate > 0) {
// Receive which split to pull from the tracker
- val splitIndex = getTrackerSelectedSplit(outputLocs)
+ val splitIndex = getTrackerSelectedSplit(myId)
if (splitIndex != -1) {
val selectedSplitInfo = outputLocs(splitIndex)
val requestSplit =
- "%d/%d/%d".format(shuffleId, selectedSplitInfo.inputId, myId)
+ "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
- requestSplit))
+ requestSplit, myId))
// splitIndex is in transit. Will be unset in the ShuffleClient
splitsInRequestBitVector.synchronized {
@@ -202,17 +211,25 @@ extends Shuffle[K, V, C] with Logging {
})
}
- private def getLocalSplitInfo: SplitInfo = {
+ private def getLocalSplitInfo(myId: Int): SplitInfo = {
var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
- SplitInfo.UnusedParam, SplitInfo.UnusedParam)
+ SplitInfo.UnusedParam, myId)
// Store hasSplits
localSplitInfo.hasSplits = hasSplits
// Store hasSplitsBitVector
hasSplitsBitVector.synchronized {
localSplitInfo.hasSplitsBitVector =
hasSplitsBitVector.clone.asInstanceOf[BitSet]
}
// Store hasBlocksInSplit to hasBlocksPerInputSplit
hasBlocksInSplit.synchronized {
localSplitInfo.hasBlocksPerInputSplit =
hasBlocksInSplit.clone.asInstanceOf[Array[Int]]
}
// Include the splitsInRequest as well
splitsInRequestBitVector.synchronized {
localSplitInfo.hasSplitsBitVector.or(splitsInRequestBitVector)
@@ -241,9 +258,9 @@ extends Shuffle[K, V, C] with Logging {
}
// Talks to the tracker and receives instruction
- private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = {
+ private def getTrackerSelectedSplit(myId: Int): Int = {
// Local status of hasSplitsBitVector and splitsInRequestBitVector
- val localSplitInfo = getLocalSplitInfo
+ val localSplitInfo = getLocalSplitInfo(myId)
// DO NOT talk to the tracker if all the required splits are already busy
if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
@@ -346,17 +363,20 @@ extends Shuffle[K, V, C] with Logging {
}
else if (reducerIntention ==
TrackedCustomBlockedLocalFileShuffle.ReducerLeaving) {
- // Receive reducerSplitInfo and serverSplitIndex
val reducerSplitInfo =
ois.readObject.asInstanceOf[SplitInfo]
- val serverSplitIndex = ois.readObject.asInstanceOf[Int]
+ // Receive reception stats: how many blocks the reducer
+ // read in how much time and from where
+ val receptionStat =
+ ois.readObject.asInstanceOf[ReceptionStats]
// Update stats
trackerStrategy.deleteReducerFrom(reducerSplitInfo,
- serverSplitIndex)
+ receptionStat)
// Send ACK
- oos.writeObject(serverSplitIndex)
+ oos.writeObject(receptionStat.serverSplitIndex)
oos.flush()
}
else {
@@ -427,7 +447,7 @@ extends Shuffle[K, V, C] with Logging {
}
class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
- requestSplit: String)
+ requestSplit: String, myId: Int)
extends Thread with Logging {
private var peerSocketToSource: Socket = null
private var oosSource: ObjectOutputStream = null
@@ -438,6 +458,10 @@ extends Shuffle[K, V, C] with Logging {
// Make sure that multiple messages don't go to the tracker
private var alreadySentLeavingNotification = false
// Keep track of bytes received and time spent
private var numBytesReceived = 0
private var totalTimeSpent = 0
override def run: Unit = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
@@ -534,6 +558,10 @@ extends Shuffle[K, V, C] with Logging {
logInfo("END READ: " + requestPath)
val readTime = System.currentTimeMillis - readStartTime
logInfo("Reading " + requestPath + " took " + readTime + " millis.")
// Update stats
numBytesReceived = numBytesReceived + requestedFileLen
totalTimeSpent = totalTimeSpent + readTime.toInt
} else {
throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
}
@@ -574,11 +602,12 @@ extends Shuffle[K, V, C] with Logging {
oosTracker.flush()
// Send reducerSplitInfo
- oosTracker.writeObject(getLocalSplitInfo)
+ oosTracker.writeObject(getLocalSplitInfo(myId))
oosTracker.flush()
- // Send serverSplitInfo so that tracker can update its stats
- oosTracker.writeObject(splitIndex)
+ // Send reception stats
+ oosTracker.writeObject(ReceptionStats(
+ numBytesReceived, totalTimeSpent, splitIndex))
oosTracker.flush()
// Receive ACK. No need to do anything with that
......
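
For orientation, a sketch of the reducer-leaving exchange as rewritten above, assuming plain Java object serialization over the already-open tracker socket (socket setup, the ReducerLeaving intention message, and error handling are omitted):

import java.io.{ObjectInputStream, ObjectOutputStream}

// Sketch only; uses the types that appear in the diff.
def reportLeaving(oosTracker: ObjectOutputStream, oisTracker: ObjectInputStream,
    localSplitInfo: SplitInfo, bytesReceived: Int, timeSpent: Int,
    splitIndex: Int): Unit = {
  // Who is leaving: the reducer's SplitInfo, now carrying hasBlocksPerInputSplit
  oosTracker.writeObject(localSplitInfo)
  oosTracker.flush()
  // What it observed: bytes read, time taken, and which mapper served them
  oosTracker.writeObject(ReceptionStats(bytesReceived, timeSpent, splitIndex))
  oosTracker.flush()
  // The tracker echoes the server split index back as an ACK; the value is ignored
  oisTracker.readObject()
}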
@@ -96,8 +96,8 @@ extends Shuffle[K, V, C] with Logging {
receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
combiners = new HashMap[K, C]
- var threadPool =
- Shuffle.newDaemonFixedThreadPool(Shuffle.MaxRxConnections)
+ var threadPool = Shuffle.newDaemonFixedThreadPool(
+ Shuffle.MaxRxConnections)
while (hasSplits < totalSplits) {
var numThreadsToCreate =
@@ -106,15 +106,15 @@ extends Shuffle[K, V, C] with Logging {
while (hasSplits < totalSplits && numThreadsToCreate > 0) {
// Receive which split to pull from the tracker
- val splitIndex = getTrackerSelectedSplit(outputLocs)
+ val splitIndex = getTrackerSelectedSplit(myId)
if (splitIndex != -1) {
val selectedSplitInfo = outputLocs(splitIndex)
val requestSplit =
- "%d/%d/%d".format(shuffleId, selectedSplitInfo.inputId, myId)
+ "%d/%d/%d".format(shuffleId, selectedSplitInfo.splitId, myId)
threadPool.execute(new ShuffleClient(splitIndex, selectedSplitInfo,
- requestSplit))
+ requestSplit, myId))
// splitIndex is in transit. Will be unset in the ShuffleClient
splitsInRequestBitVector.synchronized {
@@ -149,9 +149,9 @@ extends Shuffle[K, V, C] with Logging {
})
}
- private def getLocalSplitInfo: SplitInfo = {
+ private def getLocalSplitInfo(myId: Int): SplitInfo = {
var localSplitInfo = SplitInfo(InetAddress.getLocalHost.getHostAddress,
- SplitInfo.UnusedParam, SplitInfo.UnusedParam)
+ SplitInfo.UnusedParam, myId)
localSplitInfo.hasSplits = hasSplits
@@ -189,9 +189,9 @@ extends Shuffle[K, V, C] with Logging {
}
// Talks to the tracker and receives instruction
- private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = {
+ private def getTrackerSelectedSplit(myId: Int): Int = {
// Local status of hasSplitsBitVector and splitsInRequestBitVector
- val localSplitInfo = getLocalSplitInfo
+ val localSplitInfo = getLocalSplitInfo(myId)
// DO NOT talk to the tracker if all the required splits are already busy
if (localSplitInfo.hasSplitsBitVector.cardinality == totalSplits) {
@@ -294,17 +294,20 @@ extends Shuffle[K, V, C] with Logging {
}
else if (reducerIntention ==
TrackedCustomParallelLocalFileShuffle.ReducerLeaving) {
- // Receive reducerSplitInfo and serverSplitIndex
val reducerSplitInfo =
ois.readObject.asInstanceOf[SplitInfo]
- val serverSplitIndex = ois.readObject.asInstanceOf[Int]
+ // Receive reception stats: how many blocks the reducer
+ // read in how much time and from where
+ val receptionStat =
+ ois.readObject.asInstanceOf[ReceptionStats]
// Update stats
trackerStrategy.deleteReducerFrom(reducerSplitInfo,
- serverSplitIndex)
+ receptionStat)
// Send ACK
- oos.writeObject(serverSplitIndex)
+ oos.writeObject(receptionStat.serverSplitIndex)
oos.flush()
}
else {
@@ -375,7 +378,7 @@ extends Shuffle[K, V, C] with Logging {
}
class ShuffleClient(splitIndex: Int, serversplitInfo: SplitInfo,
- requestSplit: String)
+ requestSplit: String, myId: Int)
extends Thread with Logging {
private var peerSocketToSource: Socket = null
private var oosSource: ObjectOutputStream = null
@@ -386,6 +389,10 @@ extends Shuffle[K, V, C] with Logging {
// Make sure that multiple messages don't go to the tracker
private var alreadySentLeavingNotification = false
// Keep track of bytes received and time spent
private var numBytesReceived = 0
private var totalTimeSpent = 0
override def run: Unit = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
@@ -467,6 +474,10 @@ extends Shuffle[K, V, C] with Logging {
logInfo("END READ: " + requestPath)
val readTime = System.currentTimeMillis - readStartTime
logInfo("Reading " + requestPath + " took " + readTime + " millis.")
// Update stats
numBytesReceived = numBytesReceived + requestedFileLen
totalTimeSpent = totalTimeSpent + readTime.toInt
} else {
throw new SparkException("ShuffleServer " + serversplitInfo.hostAddress + " does not have " + requestSplit)
}
@@ -506,11 +517,12 @@ extends Shuffle[K, V, C] with Logging {
oosTracker.flush()
// Send reducerSplitInfo
- oosTracker.writeObject(getLocalSplitInfo)
+ oosTracker.writeObject(getLocalSplitInfo(myId))
oosTracker.flush()
- // Send serverSplitInfo so that tracker can update its stats
- oosTracker.writeObject(splitIndex)
+ // Send reception stats
+ oosTracker.writeObject(ReceptionStats(
+ numBytesReceived, totalTimeSpent, splitIndex))
oosTracker.flush()
// Receive ACK. No need to do anything with that
......