diff --git a/src/scala/spark/Shuffle.scala b/src/scala/spark/Shuffle.scala
index 4c5649b5378164bf162c33b90e7445b7ac412877..a80cfdb585d0cf981ce4a76f9bdcaa15239349cd 100644
--- a/src/scala/spark/Shuffle.scala
+++ b/src/scala/spark/Shuffle.scala
@@ -1,5 +1,9 @@
 package spark
 
+import java.net._
+import java.util.{BitSet}
+import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
+
 /**
  * A trait for shuffle system. Given an input RDD and combiner functions
  * for PairRDDExtras.combineByKey(), returns an output RDD.
@@ -13,3 +17,83 @@ trait Shuffle[K, V, C] {
     mergeCombiners: (C, C) => C)
   : RDD[(K, C)]
 }
+
+/**
+ * An object containing common shuffle config parameters
+ */
+private object Shuffle
+extends Logging {
+  // ShuffleTracker info
+  private var MasterHostAddress_ = System.getProperty(
+    "spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
+  private var MasterTrackerPort_ = System.getProperty(
+    "spark.shuffle.masterTrackerPort", "22222").toInt
+
+  // Used throughout the code for small and large waits/timeouts
+  private var MinKnockInterval_ = System.getProperty(
+    "spark.shuffle.minKnockInterval", "1000").toInt
+  private var MaxKnockInterval_ = System.getProperty(
+    "spark.shuffle.maxKnockInterval", "5000").toInt
+
+  // Maximum number of connections
+  private var MaxRxConnections_ = System.getProperty(
+    "spark.shuffle.maxRxConnections", "4").toInt
+  private var MaxTxConnections_ = System.getProperty(
+    "spark.shuffle.maxTxConnections", "8").toInt
+
+  def MasterHostAddress = MasterHostAddress_
+  def MasterTrackerPort = MasterTrackerPort_
+
+  def MinKnockInterval = MinKnockInterval_
+  def MaxKnockInterval = MaxKnockInterval_
+
+  def MaxRxConnections = MaxRxConnections_
+  def MaxTxConnections = MaxTxConnections_
+
+  // Returns a standard ThreadFactory except all threads are daemons
+  private def newDaemonThreadFactory: ThreadFactory = {
+    new ThreadFactory {
+      def newThread(r: Runnable): Thread = {
+        var t = Executors.defaultThreadFactory.newThread(r)
+        t.setDaemon(true)
+        return t
+      }
+    }
+  }
+
+  // Wrapper over newFixedThreadPool
+  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+    var threadPool =
+      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+    threadPool.setThreadFactory(newDaemonThreadFactory)
+
+    return threadPool
+  }
+
+  // Wrapper over newCachedThreadPool
+  def newDaemonCachedThreadPool: ThreadPoolExecutor = {
+    var threadPool =
+      Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
+
+    threadPool.setThreadFactory(newDaemonThreadFactory)
+
+    return threadPool
+  }
+}
+
+@serializable
+case class SplitInfo(val hostAddress: String, val listenPort: Int,
+  val inputId: Int) {
+
+  var hasSplits = 0
+  var hasSplitsBitVector: BitSet = null
+}
+
+object SplitInfo {
+  // Constants for special values of listenPort
+  val MappersBusy = -1
+
+  // Other constants
+  val UnusedParam = 0
+}
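Note: the `Shuffle` object above centralizes the `System.getProperty`-backed tunables and the daemon thread-pool helpers that each shuffle implementation previously carried its own copy of. A minimal standalone sketch of the same pattern (plain Scala, no Spark types; `DaemonPools` and `example.pool.size` are illustrative names, not part of this patch):

```scala
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor, TimeUnit}

object DaemonPools {
  // Same idea as Shuffle.newDaemonThreadFactory: delegate to the default
  // factory, then mark each thread as a daemon so worker threads cannot
  // keep the JVM alive on their own
  private def daemonThreadFactory: ThreadFactory = new ThreadFactory {
    def newThread(r: Runnable): Thread = {
      val t = Executors.defaultThreadFactory.newThread(r)
      t.setDaemon(true)
      t
    }
  }

  // Same idea as Shuffle.newDaemonFixedThreadPool
  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
    val pool =
      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
    pool.setThreadFactory(daemonThreadFactory)
    pool
  }

  def main(args: Array[String]) {
    // Config is read the way Shuffle reads MaxRxConnections: a JVM system
    // property with a hard-coded default ("example.pool.size" is made up)
    val pool = newDaemonFixedThreadPool(
      System.getProperty("example.pool.size", "4").toInt)
    pool.execute(new Runnable {
      def run() { println("running on " + Thread.currentThread.getName) }
    })
    pool.shutdown()
    // Wait briefly so the daemon worker gets to run before the JVM exits
    pool.awaitTermination(5, TimeUnit.SECONDS)
  }
}
```

Swapping the factory in right after construction is safe because `ThreadPoolExecutor` creates its worker threads lazily, on the first `execute` call.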
diff --git a/src/scala/spark/ShuffleTrackerStrategy.scala b/src/scala/spark/ShuffleTrackerStrategy.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d982f48bb520a4fe6316094448ae1361cab47bb0
--- /dev/null
+++ b/src/scala/spark/ShuffleTrackerStrategy.scala
@@ -0,0 +1,81 @@
+package spark
+
+/**
+ * A trait for implementing tracker strategies for the shuffle system.
+ */
+trait ShuffleTrackerStrategy {
+  // Initialize
+  def initialize(outputLocs_ : Array[SplitInfo]): Unit
+
+  // Select a split, update internal stats, and send it back
+  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int
+
+  // A reducer is done. Update internal stats
+  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+    serverSplitIndex: Int): Unit
+}
+
+/**
+ * A simple ShuffleTrackerStrategy that tries to balance the total number of
+ * connections created for each mapper.
+ */
+class BalanceConnectionsShuffleTrackerStrategy
+extends ShuffleTrackerStrategy with Logging {
+  var outputLocs: Array[SplitInfo] = null
+  var curConnectionsPerLoc: Array[Int] = null
+  var totalConnectionsPerLoc: Array[Int] = null
+
+  // The order of elements in the outputLocs (splitIndex) is used to pass
+  // information back and forth between the tracker, mappers, and reducers
+  def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
+    outputLocs = outputLocs_
+
+    // Now initialize other data structures
+    curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
+    totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
+  }
+
+  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
+    var minConnections = Int.MaxValue
+    var splitIndex = -1
+
+    for (i <- 0 until curConnectionsPerLoc.size) {
+      // TODO: Use of MaxRxConnections instead of MaxTxConnections is
+      // intentional here. MaxTxConnections is per machine whereas
+      // MaxRxConnections is per mapper/reducer. Will have to find a better way.
+      if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
+        totalConnectionsPerLoc(i) < minConnections &&
+        !reducerSplitInfo.hasSplitsBitVector.get(i)) {
+        minConnections = totalConnectionsPerLoc(i)
+        splitIndex = i
+      }
+    }
+
+    if (splitIndex != -1) {
+      curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
+      totalConnectionsPerLoc(splitIndex) =
+        totalConnectionsPerLoc(splitIndex) + 1
+
+      curConnectionsPerLoc.foreach { i =>
+        print ("" + i + " ")
+      }
+      println("")
+    }
+
+    return splitIndex
+  }
+
+  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
+    serverSplitIndex: Int): Unit = synchronized {
+    // Decrease number of active connections
+    curConnectionsPerLoc(serverSplitIndex) =
+      curConnectionsPerLoc(serverSplitIndex) - 1
+
+    assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
+
+    curConnectionsPerLoc.foreach { i =>
+      print ("" + i + " ")
+    }
+    println("")
+  }
+}
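Because the strategy now lives behind a trait, alternative selection policies can be dropped in without touching the tracker or the shuffle itself. As an illustration only (this class is hypothetical and not part of the patch), a strategy that picks a random mapper the reducer is still missing might look like:

```scala
package spark

import java.util.Random

/**
 * Hypothetical example strategy: picks uniformly at random among the
 * mappers whose output this reducer does not have yet. Shown only to
 * illustrate the ShuffleTrackerStrategy trait; not part of this patch.
 */
class RandomShuffleTrackerStrategy
extends ShuffleTrackerStrategy with Logging {
  var outputLocs: Array[SplitInfo] = null
  var ranGen = new Random

  def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
    outputLocs = outputLocs_
  }

  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
    // Assumes the tracker has filled in hasSplitsBitVector, as it does
    // for BalanceConnectionsShuffleTrackerStrategy
    val candidates = (0 until outputLocs.size).filter(
      i => !reducerSplitInfo.hasSplitsBitVector.get(i))

    // -1 tells the caller that no split could be selected, matching the
    // convention of BalanceConnectionsShuffleTrackerStrategy
    if (candidates.isEmpty) -1
    else candidates(ranGen.nextInt(candidates.size))
  }

  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
    serverSplitIndex: Int): Unit = synchronized {
    // Stateless policy: nothing to update when a reducer leaves
  }
}
```

Unlike the balancing strategy, this sketch enforces no per-mapper connection cap, so it would rely entirely on the reducer-side `MaxRxConnections` throttling.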
diff --git a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
index 44822f5143cc24d4847c07dd5c240d10e484ade1..77be47cc3769bdb1c7236c9a0f852f4cdd6231ad 100644
--- a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
+++ b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala
@@ -97,8 +97,7 @@ extends Shuffle[K, V, C] with Logging {
     combiners = new HashMap[K, C]
 
     var threadPool =
-      TrackedCustomParallelLocalFileShuffle.newDaemonFixedThreadPool(
-      TrackedCustomParallelLocalFileShuffle.MaxRxConnections)
+      Shuffle.newDaemonFixedThreadPool(Shuffle.MaxRxConnections)
 
     // Start consumer
     var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
@@ -107,8 +106,8 @@ extends Shuffle[K, V, C] with Logging {
     logInfo("ShuffleConsumer started...")
 
     while (hasSplits < totalSplits) {
-      var numThreadsToCreate = Math.min(totalSplits,
-        TrackedCustomParallelLocalFileShuffle.MaxRxConnections) -
+      var numThreadsToCreate =
+        Math.min(totalSplits, Shuffle.MaxRxConnections) -
         threadPool.getActiveCount
 
       while (hasSplits < totalSplits && numThreadsToCreate > 0) {
@@ -133,7 +132,7 @@ extends Shuffle[K, V, C] with Logging {
       }
 
       // Sleep for a while before creating new threads
-      Thread.sleep(TrackedCustomParallelLocalFileShuffle.MinKnockInterval)
+      Thread.sleep(Shuffle.MinKnockInterval)
     }
 
     threadPool.shutdown()
@@ -181,9 +180,8 @@ extends Shuffle[K, V, C] with Logging {
 
   // Talks to the tracker and receives instruction
  private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = {
-    val clientSocketToTracker = new Socket(
-      TrackedCustomParallelLocalFileShuffle.MasterHostAddress,
-      TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+    val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+      Shuffle.MasterTrackerPort)
     val oosTracker =
       new ObjectOutputStream(clientSocketToTracker.getOutputStream)
     oosTracker.flush()
@@ -219,8 +217,7 @@ class ShuffleTracker(outputLocs: Array[SplitInfo])
 extends Thread with Logging {
-  var threadPool =
-    TrackedCustomParallelLocalFileShuffle.newDaemonCachedThreadPool
+  var threadPool = Shuffle.newDaemonCachedThreadPool
   var serverSocket: ServerSocket = null
 
   // Create trackerStrategy object
@@ -236,8 +233,7 @@ extends Shuffle[K, V, C] with Logging {
   trackerStrategy.initialize(outputLocs)
 
   override def run: Unit = {
-    serverSocket = new ServerSocket(
-      TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+    serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
     logInfo("ShuffleTracker" + serverSocket)
 
     try {
@@ -392,8 +388,7 @@ extends Shuffle[K, V, C] with Logging {
       }
 
       var timeOutTimer = new Timer
-      timeOutTimer.schedule(timeOutTask,
-        TrackedCustomParallelLocalFileShuffle.MaxKnockInterval)
+      timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
 
       // Create a temp variable to be used in different places
       val requestPath = "http://%s:%d/shuffle/%s".format(
@@ -480,9 +475,8 @@ extends Shuffle[K, V, C] with Logging {
     // Connect to the tracker and update its stats
     private def sendLeavingNotification(): Unit = synchronized {
       if (!alreadySentLeavingNotification) {
-        val clientSocketToTracker =
-          new Socket(TrackedCustomParallelLocalFileShuffle.MasterHostAddress,
-          TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
+        val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
+          Shuffle.MasterTrackerPort)
         val oosTracker =
           new ObjectOutputStream(clientSocketToTracker.getOutputStream)
         oosTracker.flush()
@@ -499,7 +493,7 @@ extends Shuffle[K, V, C] with Logging {
         oosTracker.writeObject(getLocalSplitInfo)
         oosTracker.flush()
 
-        // Send serverSplitInfo so that tracker can update its stats 
+        // Send serverSplitInfo so that tracker can update its stats
         oosTracker.writeObject(splitIndex)
         oosTracker.flush()
@@ -538,89 +532,6 @@ extends Shuffle[K, V, C] with Logging {
   }
 }
 
-trait ShuffleTrackerStrategy {
-  def initialize(outputLocs_ : Array[SplitInfo]): Unit
-  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int
-  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
-    serverSplitIndex: Int): Unit
-}
-
-class BalanceConnectionsShuffleTrackerStrategy
-extends ShuffleTrackerStrategy with Logging {
-  var outputLocs: Array[SplitInfo] = null
-  var curConnectionsPerLoc: Array[Int] = null
-  var totalConnectionsPerLoc: Array[Int] = null
-
-  def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
-    outputLocs = outputLocs_
-
-    // Now initialize other data structures
-    curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
-    totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
-  }
-
-  def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
-    var minConnections = Int.MaxValue
-    var splitIndex = -1
-
-    for (i <- 0 until curConnectionsPerLoc.size) {
-      // TODO: Use of MaxRxConnections instead of MaxTxConnections is
-      // intentional here. MaxTxConnections is per machine whereas
-      // MaxRxConnections is per mapper/reducer. Will have to find a better way.
-      if (curConnectionsPerLoc(i) <
-        TrackedCustomParallelLocalFileShuffle.MaxRxConnections &&
-        totalConnectionsPerLoc(i) < minConnections &&
-        !reducerSplitInfo.hasSplitsBitVector.get(i)) {
-        minConnections = totalConnectionsPerLoc(i)
-        splitIndex = i
-      }
-    }
-
-    if (splitIndex != -1) {
-      curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
-      totalConnectionsPerLoc(splitIndex) =
-        totalConnectionsPerLoc(splitIndex) + 1
-
-      curConnectionsPerLoc.foreach { i =>
-        print ("" + i + " ")
-      }
-      println("")
-    }
-
-    return splitIndex
-  }
-
-  def deleteReducerFrom(reducerSplitInfo: SplitInfo,
-    serverSplitIndex: Int): Unit = synchronized {
-    // Decrease number of active connections
-    curConnectionsPerLoc(serverSplitIndex) =
-      curConnectionsPerLoc(serverSplitIndex) - 1
-
-    assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
-
-    curConnectionsPerLoc.foreach { i =>
-      print ("" + i + " ")
-    }
-    println("")
-  }
-}
-
-@serializable
-case class SplitInfo(val hostAddress: String, val listenPort: Int,
-  val inputId: Int) {
-
-  var hasSplits = 0
-  var hasSplitsBitVector: BitSet = null
-}
-
-object SplitInfo {
-  // Constants for special values of listenPort
-  val MappersBusy = -1
-
-  // Other constants
-  val UnusedParam = 0
-}
-
 object TrackedCustomParallelLocalFileShuffle extends Logging {
   // Tracker communication constants
   val ReducerEntering = 0
@@ -639,27 +550,6 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
   // Random number generator
   var ranGen = new Random
 
-  // Load config parameters
-
-  // ShuffleTracker info
-  private var MasterHostAddress_ = System.getProperty(
-    "spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
-  private var MasterTrackerPort_ = System.getProperty(
-    "spark.shuffle.masterTrackerPort", "22222").toInt
-
-  // Used thoughout the code for small and large waits/timeouts
-  private var MinKnockInterval_ = System.getProperty(
-    "spark.shuffle.minKnockInterval", "1000").toInt
-  private var MaxKnockInterval_ = System.getProperty(
-    "spark.shuffle.maxKnockInterval", "5000").toInt
-
-  // Maximum number of connections
-  private var MaxRxConnections_ = System.getProperty(
-    "spark.shuffle.maxRxConnections", "4").toInt
-  private var MaxTxConnections_ = System.getProperty(
-    "spark.shuffle.maxTxConnections", "8").toInt
-
   private def initializeIfNeeded() = synchronized {
     if (!initialized) {
       // TODO: localDir should be created by some mechanism common to Spark
@@ -702,15 +592,6 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
     }
   }
 
-  def MasterHostAddress = MasterHostAddress_
-  def MasterTrackerPort = MasterTrackerPort_
-
-  def MinKnockInterval = MinKnockInterval_
-  def MaxKnockInterval = MaxKnockInterval_
-
-  def MaxRxConnections = MaxRxConnections_
-  def MaxTxConnections = MaxTxConnections_
-
   def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
     initializeIfNeeded()
     val dir = new File(shuffleDir, shuffleId + "/" + inputId)
@@ -723,41 +604,9 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
     nextShuffleId.getAndIncrement()
   }
 
-  // Returns a standard ThreadFactory except all threads are daemons
-  private def newDaemonThreadFactory: ThreadFactory = {
-    new ThreadFactory {
-      def newThread(r: Runnable): Thread = {
-        var t = Executors.defaultThreadFactory.newThread(r)
-        t.setDaemon(true)
-        return t
-      }
-    }
-  }
-
-  // Wrapper over newFixedThreadPool
-  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
-    var threadPool =
-      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
-    threadPool.setThreadFactory(newDaemonThreadFactory)
-
-    return threadPool
-  }
-
-  // Wrapper over newCachedThreadPool
-  def newDaemonCachedThreadPool: ThreadPoolExecutor = {
-    var threadPool =
-      Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
-    threadPool.setThreadFactory(newDaemonThreadFactory)
-
-    return threadPool
-  }
-
   class ShuffleServer
   extends Thread with Logging {
-    var threadPool = newDaemonFixedThreadPool(
-      CustomParallelLocalFileShuffle.MaxTxConnections)
+    var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
     var serverSocket: ServerSocket = null
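All of the tunables consolidated into `Shuffle` remain plain JVM system properties, and the `private var X_` fields capture their values when the object is first referenced. A deployment can therefore override the defaults either with `-D` flags on the JVM command line or programmatically before any shuffle runs. A hypothetical launcher sketch (`ShuffleConfigExample` and the host/port values are placeholders, not part of this patch):

```scala
object ShuffleConfigExample {
  def main(args: Array[String]) {
    // These must be set before the Shuffle object is first touched,
    // because its fields read the properties at object-initialization time
    System.setProperty("spark.shuffle.masterHostAddress", "10.0.0.1")
    System.setProperty("spark.shuffle.masterTrackerPort", "22222")

    // Reducer-side and mapper-side connection caps (defaults: 4 and 8)
    System.setProperty("spark.shuffle.maxRxConnections", "4")
    System.setProperty("spark.shuffle.maxTxConnections", "8")

    // ... start the job; Shuffle.MasterHostAddress etc. pick these up
  }
}
```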