Skip to content
Snippets Groups Projects
Commit a8411eb5 authored by Mosharaf Chowdhury's avatar Mosharaf Chowdhury
Browse files

- Moving common stuff to a separate Shuffle object.

 - Moved ShuffleTrackerStrategy to a separate file.
parent 1bc10ba6
No related branches found
No related tags found
No related merge requests found
package spark
import java.net._
import java.util.{BitSet}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
/**
* A trait for shuffle system. Given an input RDD and combiner functions
* for PairRDDExtras.combineByKey(), returns an output RDD.
......@@ -13,3 +17,83 @@ trait Shuffle[K, V, C] {
mergeCombiners: (C, C) => C)
: RDD[(K, C)]
}
/**
* An object containing common shuffle config parameters
*/
private object Shuffle
extends Logging {
// ShuffleTracker info
private var MasterHostAddress_ = System.getProperty(
"spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
private var MasterTrackerPort_ = System.getProperty(
"spark.shuffle.masterTrackerPort", "22222").toInt
// Used thoughout the code for small and large waits/timeouts
private var MinKnockInterval_ = System.getProperty(
"spark.shuffle.minKnockInterval", "1000").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.shuffle.maxKnockInterval", "5000").toInt
// Maximum number of connections
private var MaxRxConnections_ = System.getProperty(
"spark.shuffle.maxRxConnections", "4").toInt
private var MaxTxConnections_ = System.getProperty(
"spark.shuffle.maxTxConnections", "8").toInt
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
def MaxRxConnections = MaxRxConnections_
def MaxTxConnections = MaxTxConnections_
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread(r)
t.setDaemon(true)
return t
}
}
}
// Wrapper over newFixedThreadPool
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
// Wrapper over newCachedThreadPool
def newDaemonCachedThreadPool: ThreadPoolExecutor = {
var threadPool =
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
}
@serializable
case class SplitInfo(val hostAddress: String, val listenPort: Int,
val inputId: Int) {
var hasSplits = 0
var hasSplitsBitVector: BitSet = null
}
object SplitInfo {
// Constants for special values of listenPort
val MappersBusy = -1
// Other constants
val UnusedParam = 0
}
package spark
/**
* A trait for implementing tracker strategies for the shuffle system.
*/
trait ShuffleTrackerStrategy {
// Initialize
def initialize(outputLocs_ : Array[SplitInfo]): Unit
// Select a split, update internal stats, and send it back
def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int
// A reducer is done. Update internal stats
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
serverSplitIndex: Int): Unit
}
/**
* A simple ShuffleTrackerStrategy that tries to balance the total number of
* connections created for each mapper.
*/
class BalanceConnectionsShuffleTrackerStrategy
extends ShuffleTrackerStrategy with Logging {
var outputLocs: Array[SplitInfo] = null
var curConnectionsPerLoc: Array[Int] = null
var totalConnectionsPerLoc: Array[Int] = null
// The order of elements in the outputLocs (splitIndex) is used to pass
// information back and forth between the tracker, mappers, and reducers
def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
outputLocs = outputLocs_
// Now initialize other data structures
curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
}
def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
var minConnections = Int.MaxValue
var splitIndex = -1
for (i <- 0 until curConnectionsPerLoc.size) {
// TODO: Use of MaxRxConnections instead of MaxTxConnections is
// intentional here. MaxTxConnections is per machine whereas
// MaxRxConnections is per mapper/reducer. Will have to find a better way.
if (curConnectionsPerLoc(i) < Shuffle.MaxRxConnections &&
totalConnectionsPerLoc(i) < minConnections &&
!reducerSplitInfo.hasSplitsBitVector.get(i)) {
minConnections = totalConnectionsPerLoc(i)
splitIndex = i
}
}
if (splitIndex != -1) {
curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
totalConnectionsPerLoc(splitIndex) =
totalConnectionsPerLoc(splitIndex) + 1
curConnectionsPerLoc.foreach { i =>
print ("" + i + " ")
}
println("")
}
return splitIndex
}
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
serverSplitIndex: Int): Unit = synchronized {
// Decrease number of active connections
curConnectionsPerLoc(serverSplitIndex) =
curConnectionsPerLoc(serverSplitIndex) - 1
assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
curConnectionsPerLoc.foreach { i =>
print ("" + i + " ")
}
println("")
}
}
......@@ -97,8 +97,7 @@ extends Shuffle[K, V, C] with Logging {
combiners = new HashMap[K, C]
var threadPool =
TrackedCustomParallelLocalFileShuffle.newDaemonFixedThreadPool(
TrackedCustomParallelLocalFileShuffle.MaxRxConnections)
Shuffle.newDaemonFixedThreadPool(Shuffle.MaxRxConnections)
// Start consumer
var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
......@@ -107,8 +106,8 @@ extends Shuffle[K, V, C] with Logging {
logInfo("ShuffleConsumer started...")
while (hasSplits < totalSplits) {
var numThreadsToCreate = Math.min(totalSplits,
TrackedCustomParallelLocalFileShuffle.MaxRxConnections) -
var numThreadsToCreate =
Math.min(totalSplits, Shuffle.MaxRxConnections) -
threadPool.getActiveCount
while (hasSplits < totalSplits && numThreadsToCreate > 0) {
......@@ -133,7 +132,7 @@ extends Shuffle[K, V, C] with Logging {
}
// Sleep for a while before creating new threads
Thread.sleep(TrackedCustomParallelLocalFileShuffle.MinKnockInterval)
Thread.sleep(Shuffle.MinKnockInterval)
}
threadPool.shutdown()
......@@ -181,9 +180,8 @@ extends Shuffle[K, V, C] with Logging {
// Talks to the tracker and receives instruction
private def getTrackerSelectedSplit(outputLocs: Array[SplitInfo]): Int = {
val clientSocketToTracker = new Socket(
TrackedCustomParallelLocalFileShuffle.MasterHostAddress,
TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
Shuffle.MasterTrackerPort)
val oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
......@@ -219,8 +217,7 @@ extends Shuffle[K, V, C] with Logging {
class ShuffleTracker(outputLocs: Array[SplitInfo])
extends Thread with Logging {
var threadPool =
TrackedCustomParallelLocalFileShuffle.newDaemonCachedThreadPool
var threadPool = Shuffle.newDaemonCachedThreadPool
var serverSocket: ServerSocket = null
// Create trackerStrategy object
......@@ -236,8 +233,7 @@ extends Shuffle[K, V, C] with Logging {
trackerStrategy.initialize(outputLocs)
override def run: Unit = {
serverSocket = new ServerSocket(
TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
serverSocket = new ServerSocket(Shuffle.MasterTrackerPort)
logInfo("ShuffleTracker" + serverSocket)
try {
......@@ -392,8 +388,7 @@ extends Shuffle[K, V, C] with Logging {
}
var timeOutTimer = new Timer
timeOutTimer.schedule(timeOutTask,
TrackedCustomParallelLocalFileShuffle.MaxKnockInterval)
timeOutTimer.schedule(timeOutTask, Shuffle.MaxKnockInterval)
// Create a temp variable to be used in different places
val requestPath = "http://%s:%d/shuffle/%s".format(
......@@ -480,9 +475,8 @@ extends Shuffle[K, V, C] with Logging {
// Connect to the tracker and update its stats
private def sendLeavingNotification(): Unit = synchronized {
if (!alreadySentLeavingNotification) {
val clientSocketToTracker =
new Socket(TrackedCustomParallelLocalFileShuffle.MasterHostAddress,
TrackedCustomParallelLocalFileShuffle.MasterTrackerPort)
val clientSocketToTracker = new Socket(Shuffle.MasterHostAddress,
Shuffle.MasterTrackerPort)
val oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
......@@ -499,7 +493,7 @@ extends Shuffle[K, V, C] with Logging {
oosTracker.writeObject(getLocalSplitInfo)
oosTracker.flush()
// Send serverSplitInfo so that tracker can update its stats
// Send serverSplitInfo so that tracker can update its stats
oosTracker.writeObject(splitIndex)
oosTracker.flush()
......@@ -538,89 +532,6 @@ extends Shuffle[K, V, C] with Logging {
}
}
trait ShuffleTrackerStrategy {
def initialize(outputLocs_ : Array[SplitInfo]): Unit
def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
serverSplitIndex: Int): Unit
}
class BalanceConnectionsShuffleTrackerStrategy
extends ShuffleTrackerStrategy with Logging {
var outputLocs: Array[SplitInfo] = null
var curConnectionsPerLoc: Array[Int] = null
var totalConnectionsPerLoc: Array[Int] = null
def initialize(outputLocs_ : Array[SplitInfo]): Unit = {
outputLocs = outputLocs_
// Now initialize other data structures
curConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
totalConnectionsPerLoc = Array.tabulate(outputLocs.size)(_ => 0)
}
def selectSplitAndAddReducer(reducerSplitInfo: SplitInfo): Int = synchronized {
var minConnections = Int.MaxValue
var splitIndex = -1
for (i <- 0 until curConnectionsPerLoc.size) {
// TODO: Use of MaxRxConnections instead of MaxTxConnections is
// intentional here. MaxTxConnections is per machine whereas
// MaxRxConnections is per mapper/reducer. Will have to find a better way.
if (curConnectionsPerLoc(i) <
TrackedCustomParallelLocalFileShuffle.MaxRxConnections &&
totalConnectionsPerLoc(i) < minConnections &&
!reducerSplitInfo.hasSplitsBitVector.get(i)) {
minConnections = totalConnectionsPerLoc(i)
splitIndex = i
}
}
if (splitIndex != -1) {
curConnectionsPerLoc(splitIndex) = curConnectionsPerLoc(splitIndex) + 1
totalConnectionsPerLoc(splitIndex) =
totalConnectionsPerLoc(splitIndex) + 1
curConnectionsPerLoc.foreach { i =>
print ("" + i + " ")
}
println("")
}
return splitIndex
}
def deleteReducerFrom(reducerSplitInfo: SplitInfo,
serverSplitIndex: Int): Unit = synchronized {
// Decrease number of active connections
curConnectionsPerLoc(serverSplitIndex) =
curConnectionsPerLoc(serverSplitIndex) - 1
assert(curConnectionsPerLoc(serverSplitIndex) >= 0)
curConnectionsPerLoc.foreach { i =>
print ("" + i + " ")
}
println("")
}
}
@serializable
case class SplitInfo(val hostAddress: String, val listenPort: Int,
val inputId: Int) {
var hasSplits = 0
var hasSplitsBitVector: BitSet = null
}
object SplitInfo {
// Constants for special values of listenPort
val MappersBusy = -1
// Other constants
val UnusedParam = 0
}
object TrackedCustomParallelLocalFileShuffle extends Logging {
// Tracker communication constants
val ReducerEntering = 0
......@@ -639,27 +550,6 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
// Random number generator
var ranGen = new Random
// Load config parameters
// ShuffleTracker info
private var MasterHostAddress_ = System.getProperty(
"spark.shuffle.masterHostAddress", InetAddress.getLocalHost.getHostAddress)
private var MasterTrackerPort_ = System.getProperty(
"spark.shuffle.masterTrackerPort", "22222").toInt
// Used thoughout the code for small and large waits/timeouts
private var MinKnockInterval_ = System.getProperty(
"spark.shuffle.minKnockInterval", "1000").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.shuffle.maxKnockInterval", "5000").toInt
// Maximum number of connections
private var MaxRxConnections_ = System.getProperty(
"spark.shuffle.maxRxConnections", "4").toInt
private var MaxTxConnections_ = System.getProperty(
"spark.shuffle.maxTxConnections", "8").toInt
private def initializeIfNeeded() = synchronized {
if (!initialized) {
// TODO: localDir should be created by some mechanism common to Spark
......@@ -702,15 +592,6 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
}
}
def MasterHostAddress = MasterHostAddress_
def MasterTrackerPort = MasterTrackerPort_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
def MaxRxConnections = MaxRxConnections_
def MaxTxConnections = MaxTxConnections_
def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
initializeIfNeeded()
val dir = new File(shuffleDir, shuffleId + "/" + inputId)
......@@ -723,41 +604,9 @@ object TrackedCustomParallelLocalFileShuffle extends Logging {
nextShuffleId.getAndIncrement()
}
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread(r)
t.setDaemon(true)
return t
}
}
}
// Wrapper over newFixedThreadPool
def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
// Wrapper over newCachedThreadPool
def newDaemonCachedThreadPool: ThreadPoolExecutor = {
var threadPool =
Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory(newDaemonThreadFactory)
return threadPool
}
class ShuffleServer
extends Thread with Logging {
var threadPool = newDaemonFixedThreadPool(
CustomParallelLocalFileShuffle.MaxTxConnections)
var threadPool = Shuffle.newDaemonFixedThreadPool(Shuffle.MaxTxConnections)
var serverSocket: ServerSocket = null
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment