diff --git a/src/scala/spark/CustomBlockedLocalFileShuffle.scala b/src/scala/spark/CustomBlockedLocalFileShuffle.scala
index a28322196bcf276951df427c095691ed9ef10862..f3d20c9f7e781d52971a594556b66388a2de31bf 100644
--- a/src/scala/spark/CustomBlockedLocalFileShuffle.scala
+++ b/src/scala/spark/CustomBlockedLocalFileShuffle.scala
@@ -21,7 +21,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class CustomBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+class CustomBlockedLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
   @transient var totalSplits = 0
   @transient var hasSplits = 0
 
@@ -76,8 +77,8 @@ class CustomBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Loggi
       buckets(i).foreach(pair => {
         // Open a new file if necessary
         if (!isDirty) {
-          file = CustomBlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i,
-            blockNum)
+          file = CustomBlockedLocalFileShuffle.getOutputFile(shuffleId,
+            myIndex, i, blockNum)
           writeStartTime = System.currentTimeMillis
           logInfo("BEGIN WRITE: " + file)
 
@@ -110,8 +111,8 @@ class CustomBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Loggi
       }
 
       // Write the BLOCKNUM file
-      file =
-        CustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, myIndex, i)
+      file = CustomBlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId,
+        myIndex, i)
       out = new ObjectOutputStream(new FileOutputStream(file))
       out.writeObject(blockNum)
       out.close()
@@ -166,6 +167,8 @@ class CustomBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Loggi
         // Sleep for a while before creating new threads
         Thread.sleep(CustomBlockedLocalFileShuffle.MinKnockInterval)
       }
+
+      threadPool.shutdown()
       combiners
     })
   }
diff --git a/src/scala/spark/CustomParallelLocalFileShuffle.scala b/src/scala/spark/CustomParallelLocalFileShuffle.scala
index 8f5fad4a85fb57c738fd977849fb6f418ef53b73..aa08d1a195a002bd4cb4ea2e4a6faf64e800ede4 100644
--- a/src/scala/spark/CustomParallelLocalFileShuffle.scala
+++ b/src/scala/spark/CustomParallelLocalFileShuffle.scala
@@ -16,7 +16,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+class CustomParallelLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
   @transient var totalSplits = 0
   @transient var hasSplits = 0
   @transient var hasSplitsBitVector: BitSet = null
@@ -57,17 +58,19 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
       }
 
       for (i <- 0 until numOutputSplits) {
-        val file = CustomParallelLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+        val file = CustomParallelLocalFileShuffle.getOutputFile(shuffleId,
+          myIndex, i)
         val writeStartTime = System.currentTimeMillis
-        logInfo ("BEGIN WRITE: " + file)
+        logInfo("BEGIN WRITE: " + file)
         val out = new ObjectOutputStream(new FileOutputStream(file))
         buckets(i).foreach(pair => out.writeObject(pair))
         out.close()
-        logInfo ("END WRITE: " + file)
-        val writeTime = (System.currentTimeMillis - writeStartTime)
-        logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+        logInfo("END WRITE: " + file)
+        val writeTime = System.currentTimeMillis - writeStartTime
+        logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
       }
 
-      (myIndex, CustomParallelLocalFileShuffle.serverAddress, CustomParallelLocalFileShuffle.serverPort)
+      (myIndex, CustomParallelLocalFileShuffle.serverAddress,
+        CustomParallelLocalFileShuffle.serverPort)
     }).collect()
 
     val splitsByUri = new ArrayBuffer[(String, Int, Int)]
@@ -82,16 +85,16 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
     return indexes.flatMap((myId: Int) => {
       totalSplits = splitsByUri.size
       hasSplits = 0
-      hasSplitsBitVector = new BitSet (totalSplits)
-      splitsInRequestBitVector = new BitSet (totalSplits)
+      hasSplitsBitVector = new BitSet(totalSplits)
+      splitsInRequestBitVector = new BitSet(totalSplits)
       combiners = new HashMap[K, C]
 
-      var threadPool =
-        CustomParallelLocalFileShuffle.newDaemonFixedThreadPool (CustomParallelLocalFileShuffle.MaxConnections)
+      var threadPool = CustomParallelLocalFileShuffle.newDaemonFixedThreadPool(
+        CustomParallelLocalFileShuffle.MaxConnections)
 
       while (hasSplits < totalSplits) {
-        var numThreadsToCreate =
-          Math.min (totalSplits, CustomParallelLocalFileShuffle.MaxConnections) -
+        var numThreadsToCreate = Math.min(totalSplits,
+          CustomParallelLocalFileShuffle.MaxConnections) -
           threadPool.getActiveCount
 
         while (hasSplits < totalSplits && numThreadsToCreate > 0) {
@@ -99,15 +102,15 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
           val splitIndex = selectRandomSplit
 
           if (splitIndex != -1) {
-            val (serverAddress, serverPort, inputId) = splitsByUri (splitIndex)
+            val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
             val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
 
-            threadPool.execute (new ShuffleClient (splitIndex, serverAddress,
+            threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
               serverPort, requestPath, mergeCombiners))
 
             // splitIndex is in transit. Will be unset in the ShuffleClient
             splitsInRequestBitVector.synchronized {
-              splitsInRequestBitVector.set (splitIndex)
+              splitsInRequestBitVector.set(splitIndex)
             }
           }
 
@@ -115,10 +118,10 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
         }
 
         // Sleep for a while before creating new threads
-        Thread.sleep (CustomParallelLocalFileShuffle.MinKnockInterval)
+        Thread.sleep(CustomParallelLocalFileShuffle.MinKnockInterval)
       }
 
-      threadPool.shutdown
+      threadPool.shutdown()
       combiners
     })
   }
@@ -135,13 +138,14 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
     }
 
     if (requiredSplits.size > 0) {
-      requiredSplits(CustomParallelLocalFileShuffle.ranGen.nextInt (requiredSplits.size))
+      requiredSplits(CustomParallelLocalFileShuffle.ranGen.nextInt(
+        requiredSplits.size))
     } else {
       -1
     }
   }
 
-  class ShuffleClient (splitIndex: Int, hostAddress: String, listenPort: Int,
+  class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
     requestPath: String, mergeCombiners: (C, C) => C)
   extends Thread with Logging {
     private var peerSocketToSource: Socket = null
@@ -152,43 +156,44 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
     override def run: Unit = {
       val readStartTime = System.currentTimeMillis
-      logInfo ("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
+      logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
 
       // Setup the timeout mechanism
       var timeOutTask = new TimerTask {
         override def run: Unit = {
-          cleanUpConnections
+          cleanUpConnections()
         }
       }
 
       var timeOutTimer = new Timer
-      timeOutTimer.schedule (timeOutTask, CustomParallelLocalFileShuffle.MaxKnockInterval)
+      timeOutTimer.schedule(timeOutTask,
+        CustomParallelLocalFileShuffle.MaxKnockInterval)
 
-      logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
+      logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
 
       try {
         // Connect to the source
-        peerSocketToSource = new Socket (hostAddress, listenPort)
+        peerSocketToSource = new Socket(hostAddress, listenPort)
         oosSource =
-          new ObjectOutputStream (peerSocketToSource.getOutputStream)
-        oosSource.flush
+          new ObjectOutputStream(peerSocketToSource.getOutputStream)
+        oosSource.flush()
         var isSource = peerSocketToSource.getInputStream
-        oisSource = new ObjectInputStream (isSource)
+        oisSource = new ObjectInputStream(isSource)
 
         // Send the request
         oosSource.writeObject(requestPath)
 
         // Receive the length of the requested file
         var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
-        logInfo ("Received requestedFileLen = " + requestedFileLen)
+        logInfo("Received requestedFileLen = " + requestedFileLen)
 
         // Turn the timer OFF, if the sender responds before timeout
-        timeOutTimer.cancel
+        timeOutTimer.cancel()
 
         // Receive the file
         if (requestedFileLen != -1) {
           // Add this to combiners
-          val inputStream = new ObjectInputStream (isSource)
+          val inputStream = new ObjectInputStream(isSource)
 
           try{
             while (true) {
@@ -203,24 +208,24 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
             }
           } catch {
             case e: EOFException => { }
           }
-          inputStream.close
+          inputStream.close()
 
           // Reception completed. Update stats.
           hasSplitsBitVector.synchronized {
-            hasSplitsBitVector.set (splitIndex)
+            hasSplitsBitVector.set(splitIndex)
           }
           hasSplits += 1
 
           // We have received splitIndex
           splitsInRequestBitVector.synchronized {
-            splitsInRequestBitVector.set (splitIndex, false)
+            splitsInRequestBitVector.set(splitIndex, false)
           }
 
           receptionSucceeded = true
 
-          logInfo ("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
-          val readTime = (System.currentTimeMillis - readStartTime)
-          logInfo ("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath) + " took " + readTime + " millis.")
+          logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
+          val readTime = System.currentTimeMillis - readStartTime
+          logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath) + " took " + readTime + " millis.")
         } else {
           throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
         }
@@ -229,28 +234,28 @@ class CustomParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logg
         // connection due to timeout
         case eofe: java.io.EOFException => { }
         case e: Exception => {
-          logInfo ("ShuffleClient had a " + e)
+          logInfo("ShuffleClient had a " + e)
         }
       } finally {
         // If reception failed, unset for future retry
         if (!receptionSucceeded) {
           splitsInRequestBitVector.synchronized {
-            splitsInRequestBitVector.set (splitIndex, false)
+            splitsInRequestBitVector.set(splitIndex, false)
           }
         }
-        cleanUpConnections
+        cleanUpConnections()
       }
     }
 
-    private def cleanUpConnections: Unit = {
+    private def cleanUpConnections(): Unit = {
       if (oisSource != null) {
-        oisSource.close
+        oisSource.close()
       }
       if (oosSource != null) {
-        oosSource.close
+        oosSource.close()
       }
       if (peerSocketToSource != null) {
-        peerSocketToSource.close
+        peerSocketToSource.close()
       }
     }
   }
@@ -280,13 +285,13 @@ object CustomParallelLocalFileShuffle extends Logging {
   private def initializeIfNeeded() = synchronized {
     if (!initialized) {
       // Load config parameters
-      MinKnockInterval_ =
-        System.getProperty ("spark.parallelLocalFileShuffle.MinKnockInterval", "1000").toInt
-      MaxKnockInterval_ =
-        System.getProperty ("spark.parallelLocalFileShuffle.MaxKnockInterval", "5000").toInt
+      MinKnockInterval_ = System.getProperty(
+        "spark.parallelLocalFileShuffle.MinKnockInterval", "1000").toInt
+      MaxKnockInterval_ = System.getProperty(
+        "spark.parallelLocalFileShuffle.MaxKnockInterval", "5000").toInt
 
-      MaxConnections_ =
-        System.getProperty ("spark.parallelLocalFileShuffle.MaxConnections", "4").toInt
+      MaxConnections_ = System.getProperty(
+        "spark.parallelLocalFileShuffle.MaxConnections", "4").toInt
 
       // TODO: localDir should be created by some mechanism common to Spark
       // so that it can be shared among shuffle, broadcast, etc
@@ -299,9 +304,9 @@ object CustomParallelLocalFileShuffle extends Logging {
       while (!foundLocalDir && tries < 10) {
         tries += 1
         try {
-          localDirUuid = UUID.randomUUID()
+          localDirUuid = UUID.randomUUID
           localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
-          if (!localDir.exists()) {
+          if (!localDir.exists) {
            localDir.mkdirs()
            foundLocalDir = true
          }
@@ -320,9 +325,9 @@ object CustomParallelLocalFileShuffle extends Logging {
 
       // Create and start the shuffleServer
       shuffleServer = new ShuffleServer
-      shuffleServer.setDaemon (true)
+      shuffleServer.setDaemon(true)
       shuffleServer.start
-      logInfo ("ShuffleServer started...")
+      logInfo("ShuffleServer started...")
 
       initialized = true
     }
@@ -349,77 +354,78 @@ object CustomParallelLocalFileShuffle extends Logging {
   private def newDaemonThreadFactory: ThreadFactory = {
     new ThreadFactory {
       def newThread(r: Runnable): Thread = {
-        var t = Executors.defaultThreadFactory.newThread (r)
-        t.setDaemon (true)
+        var t = Executors.defaultThreadFactory.newThread(r)
+        t.setDaemon(true)
         return t
       }
     }
   }
 
   // Wrapper over newFixedThreadPool
-  def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
+  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
     var threadPool =
-      Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
+      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
 
-    threadPool.setThreadFactory (newDaemonThreadFactory)
+    threadPool.setThreadFactory(newDaemonThreadFactory)
 
     return threadPool
-  } 
+  }
 
   class ShuffleServer extends Thread with Logging {
-    var threadPool = newDaemonFixedThreadPool(CustomParallelLocalFileShuffle.MaxConnections)
+    var threadPool =
+      newDaemonFixedThreadPool(CustomParallelLocalFileShuffle.MaxConnections)
 
     var serverSocket: ServerSocket = null
 
     override def run: Unit = {
-      serverSocket = new ServerSocket (0)
+      serverSocket = new ServerSocket(0)
       serverPort = serverSocket.getLocalPort
 
-      logInfo ("ShuffleServer started with " + serverSocket)
-      logInfo ("Local URI: http://" + serverAddress + ":" + serverPort)
+      logInfo("ShuffleServer started with " + serverSocket)
+      logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
 
       try {
         while (true) {
           var clientSocket: Socket = null
           try {
-            clientSocket = serverSocket.accept
+            clientSocket = serverSocket.accept()
           } catch {
             case e: Exception => { }
           }
           if (clientSocket != null) {
-            logInfo ("Serve: Accepted new client connection:" + clientSocket)
+            logInfo("Serve: Accepted new client connection:" + clientSocket)
             try {
-              threadPool.execute (new ShuffleServerThread (clientSocket))
+              threadPool.execute(new ShuffleServerThread(clientSocket))
             } catch {
               // In failure, close socket here; else, the thread will close it
              case ioe: IOException => {
-                clientSocket.close
+                clientSocket.close()
              }
            }
          }
        }
      } finally {
        if (serverSocket != null) {
-          logInfo ("ShuffleServer now stopping...")
-          serverSocket.close
+          logInfo("ShuffleServer now stopping...")
+          serverSocket.close()
        }
      }
      // Shutdown the thread pool
-      threadPool.shutdown
+      threadPool.shutdown()
    }
 
-    class ShuffleServerThread (val clientSocket: Socket)
+    class ShuffleServerThread(val clientSocket: Socket)
     extends Thread with Logging {
       private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
-      os.flush
-      private val bos = new BufferedOutputStream (os)
-      bos.flush
-      private val oos = new ObjectOutputStream (os)
-      oos.flush
-      private val ois = new ObjectInputStream (clientSocket.getInputStream)
+      os.flush()
+      private val bos = new BufferedOutputStream(os)
+      bos.flush()
+      private val oos = new ObjectOutputStream(os)
+      oos.flush()
+      private val ois = new ObjectInputStream(clientSocket.getInputStream)
 
-      logInfo ("new ShuffleServerThread is running")
+      logInfo("new ShuffleServerThread is running")
 
       override def run: Unit = {
         try {
@@ -442,49 +448,49 @@ object CustomParallelLocalFileShuffle extends Logging {
           // In the case of receiver timeout and connection close, this will
           // throw a java.net.SocketException: Broken pipe
           oos.writeObject(requestedFileLen)
-          oos.flush
+          oos.flush()
 
-          logInfo ("requestedFileLen = " + requestedFileLen)
+          logInfo("requestedFileLen = " + requestedFileLen)
 
           // Read and send the requested file
           if (requestedFileLen != -1) {
             // Read
             var byteArray = new Array[Byte](requestedFileLen)
             val bis =
-              new BufferedInputStream (new FileInputStream (requestedFile))
+              new BufferedInputStream(new FileInputStream(requestedFile))
 
-            var bytesRead = bis.read (byteArray, 0, byteArray.length)
+            var bytesRead = bis.read(byteArray, 0, byteArray.length)
             var alreadyRead = bytesRead
 
             while (alreadyRead < requestedFileLen) {
-              bytesRead = bis.read(byteArray, alreadyRead, 
+              bytesRead = bis.read(byteArray, alreadyRead,
                 (byteArray.length - alreadyRead))
               if(bytesRead > 0) {
                 alreadyRead = alreadyRead + bytesRead
               }
             }
-            bis.close
+            bis.close()
 
             // Send
-            bos.write (byteArray, 0, byteArray.length)
-            bos.flush
+            bos.write(byteArray, 0, byteArray.length)
+            bos.flush()
           } else {
             // Close the connection
           }
         } catch {
-          // If something went wrong, e.g., the worker at the other end died etc.
+          // If something went wrong, e.g., the worker at the other end died etc
           // then close everything up
           // Exception can happen if the receiver stops receiving
           case e: Exception => {
-            logInfo ("ShuffleServerThread had a " + e)
+            logInfo("ShuffleServerThread had a " + e)
           }
         } finally {
-          logInfo ("ShuffleServerThread is closing streams and sockets")
-          ois.close
+          logInfo("ShuffleServerThread is closing streams and sockets")
+          ois.close()
           // TODO: Following can cause "java.net.SocketException: Socket closed"
-          oos.close
-          bos.close
-          clientSocket.close
+          oos.close()
+          bos.close()
+          clientSocket.close()
         }
       }
     }
diff --git a/src/scala/spark/HttpBlockedLocalFileShuffle.scala b/src/scala/spark/HttpBlockedLocalFileShuffle.scala
index cd4e0abceca3635e651d6eaee2db24ce7f3e31c4..be26dabc5f1fdbbad697c96a3ae630b19e9d38f1 100644
--- a/src/scala/spark/HttpBlockedLocalFileShuffle.scala
+++ b/src/scala/spark/HttpBlockedLocalFileShuffle.scala
@@ -21,7 +21,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class HttpBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+class HttpBlockedLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
   @transient var totalSplits = 0
   @transient var hasSplits = 0
@@ -69,7 +70,8 @@ class HttpBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging
       for (i <- 0 until numOutputSplits) {
         // Open the INDEX file
         var indexFile: File =
-          HttpBlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId, myIndex, i)
+          HttpBlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId,
+            myIndex, i)
         var indexOut = new ObjectOutputStream(new FileOutputStream(indexFile))
         var indexDirty: Boolean = true
         var alreadyWritten: Long = 0
@@ -88,7 +90,8 @@ class HttpBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging
           indexDirty = true
 
           // Update the INDEX file if more than blockSize limit has been written
-          if (file.length - alreadyWritten > HttpBlockedLocalFileShuffle.BlockSize) {
+          if (file.length - alreadyWritten >
+            HttpBlockedLocalFileShuffle.BlockSize) {
             indexOut.writeObject(file.length)
             indexDirty = false
             alreadyWritten = file.length
@@ -158,6 +161,8 @@ class HttpBlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging
         // Sleep for a while before creating new threads
         Thread.sleep(HttpBlockedLocalFileShuffle.MinKnockInterval)
       }
+
+      threadPool.shutdown()
       combiners
     })
   }
diff --git a/src/scala/spark/HttpParallelLocalFileShuffle.scala b/src/scala/spark/HttpParallelLocalFileShuffle.scala
index 89daccf7a732a6b5123eb792cf49745f28323e0c..b9baeaee30a2ed3386017d0b9c713a4e50ad85fa 100644
--- a/src/scala/spark/HttpParallelLocalFileShuffle.scala
+++ b/src/scala/spark/HttpParallelLocalFileShuffle.scala
@@ -16,7 +16,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class HttpParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+class HttpParallelLocalFileShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
   @transient var totalSplits = 0
   @transient var hasSplits = 0
@@ -58,7 +59,8 @@ class HttpParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Loggin
       }
 
       for (i <- 0 until numOutputSplits) {
-        val file = HttpParallelLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+        val file =
+          HttpParallelLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
         val writeStartTime = System.currentTimeMillis
         logInfo("BEGIN WRITE: " + file)
         val out = new ObjectOutputStream(new FileOutputStream(file))
@@ -115,6 +117,8 @@ class HttpParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Loggin
         // Sleep for a while before creating new threads
         Thread.sleep(HttpParallelLocalFileShuffle.MinKnockInterval)
       }
+
+      threadPool.shutdown()
      combiners
    })
  }
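
Note: amid the no-op style cleanup above, the one behavioral change in this patch is that each reduce-side fetch loop now calls threadPool.shutdown() once all splits have arrived, so every shuffle releases its fixed pool of fetcher threads instead of leaving them idle. Below is a minimal, self-contained sketch of the daemon-pool idiom these classes share; newDaemonThreadFactory and newDaemonFixedThreadPool follow the patch, while the object name and main method are illustrative additions, not part of the patch:

    import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

    object DaemonPoolSketch {
      // Thread factory that marks every worker as a daemon, so a pool that
      // is never shut down still cannot keep the JVM alive by itself
      private def newDaemonThreadFactory: ThreadFactory = new ThreadFactory {
        def newThread(r: Runnable): Thread = {
          val t = Executors.defaultThreadFactory.newThread(r)
          t.setDaemon(true)
          t
        }
      }

      // Wrapper over newFixedThreadPool, as in the patch
      def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
        val pool =
          Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
        pool.setThreadFactory(newDaemonThreadFactory)
        pool
      }

      def main(args: Array[String]): Unit = {
        val pool = newDaemonFixedThreadPool(4)
        for (i <- 1 to 8) {
          pool.execute(new Runnable {
            def run(): Unit =
              println("task " + i + " on " + Thread.currentThread.getName)
          })
        }
        // What the patch adds at the end of each fetch loop: let the
        // pool's worker threads exit once the work is done
        pool.shutdown()
      }
    }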
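On the client side, MinKnockInterval is the sleep between rounds of spawning fetchers, while MaxKnockInterval is a per-connection timeout armed via java.util.Timer before connecting and cancelled once the sender responds. The following stripped-down sketch of that timeout idiom uses the same 5000 ms default the patch reads from spark.parallelLocalFileShuffle.MaxKnockInterval; the object name, fetch signature, and the elided request/response exchange are placeholders, not the patch's code:

    import java.net.Socket
    import java.util.{Timer, TimerTask}

    object KnockTimeoutSketch {
      val MaxKnockInterval = 5000 // ms, matching the patch's default

      def fetch(hostAddress: String, listenPort: Int): Unit = {
        var socket: Socket = null
        def cleanUp(): Unit = synchronized {
          if (socket != null) { socket.close(); socket = null }
        }

        // Arm the timeout before connecting; on expiry the TimerTask tears
        // the connection down, as ShuffleClient's cleanUpConnections() does
        val timeOutTimer = new Timer(true) // daemon timer thread
        timeOutTimer.schedule(new TimerTask {
          override def run(): Unit = cleanUp()
        }, MaxKnockInterval)

        try {
          socket = new Socket(hostAddress, listenPort)
          // ... write the request and read the reply here ...
          timeOutTimer.cancel() // sender responded before the timeout
        } catch {
          // A socket closed by the timer surfaces here as an IOException,
          // which the caller treats as a failed fetch and retries later
          case e: Exception => ()
        } finally {
          cleanUp()
        }
      }
    }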