"git@gitlab.engr.illinois.edu:croft1/floodlight.git" did not exist on "4ecd679298b83a833a5d7f3ec8821a9770e530df"
-
Mosharaf Chowdhury authoredMosharaf Chowdhury authored
LocalFileShuffle.scala 12.96 KiB
package spark
import java.io._
import java.net._
import java.util.{Timer, TimerTask, UUID}
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
import scala.collection.mutable.{ArrayBuffer, HashMap}
/**
* A simple implementation of shuffle using local files served through HTTP.
*
* TODO: Add support for compression when spark.compress is set to true.
*/
@serializable
class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
override def compute(input: RDD[(K, V)],
numOutputSplits: Int,
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C)
: RDD[(K, C)] =
{
val sc = input.sparkContext
val shuffleId = LocalFileShuffle.newShuffleId()
logInfo("Shuffle ID: " + shuffleId)
val splitRdd = new NumberedSplitRDD(input)
val numInputSplits = splitRdd.splits.size
    // Run a parallel map and collect to write the intermediate data files,
    // returning a list of (inputSplitId, serverAddress, serverPort) triples
val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
val myIndex = pair._1
val myIterator = pair._2
val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
for ((k, v) <- myIterator) {
var bucketId = k.hashCode % numOutputSplits
if (bucketId < 0) { // Fix bucket ID if hash code was negative
bucketId += numOutputSplits
}
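        // e.g. with numOutputSplits = 4, a key whose hashCode is -3 gives
        // -3 % 4 == -3 in Scala, and the fix above maps it to bucket 1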
val bucket = buckets(bucketId)
bucket(k) = bucket.get(k) match {
case Some(c) => mergeValue(c, v)
case None => createCombiner(v)
}
}
for (i <- 0 until numOutputSplits) {
val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
val writeStartTime = System.currentTimeMillis
logInfo ("BEGIN WRITE: " + file)
val out = new ObjectOutputStream(new FileOutputStream(file))
buckets(i).foreach(pair => out.writeObject(pair))
out.close()
logInfo ("END WRITE: " + file)
val writeTime = (System.currentTimeMillis - writeStartTime)
logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
}
(myIndex, LocalFileShuffle.serverAddress, LocalFileShuffle.serverPort)
}).collect()
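    // collect() returns the (inputId, serverAddress, serverPort) triples to
    // the driver, which passes them on to the reduce tasks built below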
val splitsByUri = new ArrayBuffer[(String, Int, Int)]
for ((inputId, serverAddress, serverPort) <- outputLocs) {
splitsByUri += ((serverAddress, serverPort, inputId))
}
// TODO: Could broadcast splitsByUri
// Return an RDD that does each of the merges for a given partition
val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
return indexes.flatMap((myId: Int) => {
val combiners = new HashMap[K, C]
for ((serverAddress, serverPort, inputId) <- splitsByUri) {
val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
val shuffleClient = new ShuffleClient(serverAddress, serverPort, requestPath)
val readStartTime = System.currentTimeMillis
logInfo ("BEGIN READ: " + requestPath)
shuffleClient.start
shuffleClient.join
val inputStream = new ObjectInputStream (
new ByteArrayInputStream(shuffleClient.byteArray))
try {
while (true) {
val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
combiners(k) = combiners.get(k) match {
case Some(oldC) => mergeCombiners(oldC, c)
case None => c
}
}
} catch {
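          // ObjectInputStream signals end-of-stream with an EOFException;
          // treat it as "all pairs read" rather than an error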
case e: EOFException => {}
}
inputStream.close
logInfo ("END READ: " + requestPath)
val readTime = (System.currentTimeMillis - readStartTime)
logInfo ("Reading " + requestPath + " took " + readTime + " millis.")
}
combiners
})
}
}
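
// A minimal usage sketch (hypothetical: `sc`, the input file, and the
// word-count combiners below are illustrative, not part of this file):
//
//   val pairs: RDD[(String, Int)] = sc.textFile("words.txt").map(w => (w, 1))
//   val counts = new LocalFileShuffle[String, Int, Int].compute(
//     pairs, 4,                       // input RDD, number of output splits
//     (v: Int) => v,                  // createCombiner
//     (c: Int, v: Int) => c + v,      // mergeValue
//     (c1: Int, c2: Int) => c1 + c2)  // mergeCombiners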
object LocalFileShuffle extends Logging {
private var initialized = false
  private val nextShuffleId = new AtomicLong(0)
// Variables initialized by initializeIfNeeded()
private var shuffleDir: File = null
private var shuffleServer: ShuffleServer = null
  private val serverAddress = InetAddress.getLocalHost.getHostAddress
private var serverPort: Int = -1
private def initializeIfNeeded() = synchronized {
if (!initialized) {
// TODO: localDir should be created by some mechanism common to Spark
// so that it can be shared among shuffle, broadcast, etc
val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
var tries = 0
var foundLocalDir = false
var localDir: File = null
var localDirUuid: UUID = null
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID()
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
          // mkdirs() can return false on failure, so only treat the
          // directory as found if it was actually created
          if (!localDir.exists() && localDir.mkdirs()) {
            foundLocalDir = true
          }
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir failed", e)
}
}
if (!foundLocalDir) {
logError("Failed 10 attempts to create local dir in " + localDirRoot)
System.exit(1)
}
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
// Create and start the shuffleServer
shuffleServer = new ShuffleServer
shuffleServer.setDaemon (true)
shuffleServer.start
logInfo ("ShuffleServer started...")
initialized = true
}
}
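  // Returns the output file for a (shuffleId, inputId, outputId) triple,
  // laid out on disk as <shuffleDir>/<shuffleId>/<inputId>/<outputId> to
  // match the "%d/%d/%d" request paths that ShuffleClient builds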
def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
initializeIfNeeded()
val dir = new File(shuffleDir, shuffleId + "/" + inputId)
dir.mkdirs()
val file = new File(dir, "" + outputId)
return file
}
def newShuffleId(): Long = {
nextShuffleId.getAndIncrement()
}
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
        val t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
// Wrapper over newFixedThreadPool
def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
    val threadPool =
      Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
class ShuffleServer
extends Thread with Logging {
// TODO: Set config param
var threadPool = newDaemonFixedThreadPool(2)
var serverSocket: ServerSocket = null
override def run: Unit = {
serverSocket = new ServerSocket (0)
serverPort = serverSocket.getLocalPort
logInfo ("ShuffleServer started with " + serverSocket)
logInfo ("Local URI: " + serverAddress + ":" + serverPort)
try {
while (true) {
var clientSocket: Socket = null
try {
clientSocket = serverSocket.accept
} catch {
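            // Ignore failed accepts and keep looping; a closed serverSocket
            // lands here as well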
case e: Exception => { }
}
if (clientSocket != null) {
logInfo ("Serve: Accepted new client connection:" + clientSocket)
try {
threadPool.execute (new ShuffleServerThread (clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => {
clientSocket.close
}
}
}
}
} finally {
if (serverSocket != null) {
logInfo ("ShuffleServer now stopping...")
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ShuffleServerThread (val clientSocket: Socket)
extends Thread with Logging {
      private val os = clientSocket.getOutputStream
os.flush
private val oos = new ObjectOutputStream (os)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
logInfo ("new ShuffleServerThread is running")
override def run: Unit = {
try {
// Receive requestPath from the receiver
          val requestPath = ois.readObject.asInstanceOf[String]
logInfo("requestPath: " + shuffleDir + "/" + requestPath)
          // Open the file. Note that File.length returns 0 for a missing
          // file rather than throwing, so check exists() explicitly to keep
          // -1 as the "not found" marker
          var requestedFile: File = null
          var requestedFileLen = -1
          requestedFile = new File(shuffleDir + "/" + requestPath)
          if (requestedFile.exists) {
            requestedFileLen = requestedFile.length.toInt
          }
          // Send the length of the requested file to let the receiver know
          // that the transfer is about to start
// In the case of receiver timeout and connection close, this will
// throw a java.net.SocketException: Broken pipe
oos.writeObject(requestedFileLen)
oos.flush
logInfo ("requestedFileLen = " + requestedFileLen)
// Read and send the requested file
if (requestedFileLen != -1) {
// Read
            val byteArray = new Array[Byte](requestedFileLen)
            val bis =
              new BufferedInputStream (new FileInputStream (requestedFile))
            var bytesRead = bis.read (byteArray, 0, byteArray.length)
            var alreadyRead = bytesRead
            // read() may return fewer bytes than requested, so loop until
            // the whole file is in memory; -1 would mean the file was
            // truncated while being served
            while (alreadyRead < requestedFileLen) {
              bytesRead = bis.read (byteArray, alreadyRead,
                byteArray.length - alreadyRead)
              if (bytesRead == -1) {
                throw new EOFException ("Unexpected EOF while reading " + requestedFile)
              }
              alreadyRead += bytesRead
            }
bis.close
// Send
os.write (byteArray, 0, byteArray.length)
os.flush
          } else {
            // requestedFileLen == -1: nothing to send, just fall through
            // and let the finally block close the connection
          }
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
// Exception can happen if the receiver stops receiving
case e: Exception => {
logInfo ("ShuffleServerThread had a " + e)
}
} finally {
logInfo ("ShuffleServerThread is closing streams and sockets")
ois.close
// TODO: Following can cause "java.net.SocketException: Socket closed"
oos.close
clientSocket.close
}
}
}
}
}
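// Wire protocol between ShuffleClient and ShuffleServerThread:
//   1. The client sends the request path ("shuffleId/inputId/outputId") as a
//      serialized String.
//   2. The server replies with the file length as a serialized Int, or -1 if
//      the file cannot be served.
//   3. The server streams the raw file bytes, which the client reads fully
//      into byteArray.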
class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String)
extends Thread with Logging {
private var peerSocketToSource: Socket = null
private var oosSource: ObjectOutputStream = null
private var oisSource: ObjectInputStream = null
var byteArray: Array[Byte] = null
override def run: Unit = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
override def run: Unit = {
cleanUpConnections
}
}
var timeOutTimer = new Timer
// TODO: Set wait timer
timeOutTimer.schedule (timeOutTask, 1000)
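    // If the server does not respond before the timer fires,
    // cleanUpConnections closes the socket so the blocking reads below
    // fail fast instead of hanging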
logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
try {
// Connect to the source
peerSocketToSource = new Socket (hostAddress, listenPort)
oosSource =
new ObjectOutputStream (peerSocketToSource.getOutputStream)
oosSource.flush
var isSource = peerSocketToSource.getInputStream
oisSource = new ObjectInputStream (isSource)
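      // Note: the length header is read through oisSource, but the file
      // bytes are read from the raw isSource stream; this relies on the
      // ObjectInputStream not having buffered past the header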
// Send the request
oosSource.writeObject(requestPath)
// Receive the length of the requested file
var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
logInfo ("Received requestedFileLen = " + requestedFileLen)
// Turn the timer OFF, if the sender responds before timeout
timeOutTimer.cancel
// Receive the file
if (requestedFileLen != -1) {
        byteArray = new Array[Byte] (requestedFileLen)
        var bytesRead = isSource.read (byteArray, 0, byteArray.length)
        var alreadyRead = bytesRead
        // Keep reading until the whole file has arrived; read() may return
        // fewer bytes than requested, and -1 means the sender closed the
        // connection early
        while (alreadyRead < requestedFileLen) {
          bytesRead = isSource.read (byteArray, alreadyRead,
            byteArray.length - alreadyRead)
          if (bytesRead == -1) {
            throw new SparkException ("Connection closed before " + requestPath +
              " was fully received")
          }
          alreadyRead += bytesRead
        }
} else {
throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
}
} catch {
// EOFException is expected to happen because sender can break
// connection due to timeout
case eofe: java.io.EOFException => { }
case e: Exception => {
logInfo ("ShuffleClient had a " + e)
}
} finally {
cleanUpConnections
}
}
private def cleanUpConnections: Unit = {
if (oisSource != null) {
oisSource.close
}
if (oosSource != null) {
oosSource.close
}
if (peerSocketToSource != null) {
peerSocketToSource.close
}
}
}