diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala index eb9bd40698351c843ab31bf50e863bf759831b4a..79a3ab74dd9098f99fd1857b65482a338a904d5a 100644 --- a/src/scala/spark/LocalFileShuffle.scala +++ b/src/scala/spark/LocalFileShuffle.scala @@ -2,7 +2,7 @@ package spark import java.io._ import java.net._ -import java.util.{Timer, TimerTask, UUID} +import java.util.{BitSet, Random, Timer, TimerTask, UUID} import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory} @@ -15,6 +15,11 @@ import scala.collection.mutable.{ArrayBuffer, HashMap} */ @serializable class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { + @transient var totalSplits = 0 + @transient var hasSplits = 0 + @transient var hasSplitsBitVector: BitSet = null + @transient var combiners: HashMap[K,C] = null + override def compute(input: RDD[(K, V)], numOutputSplits: Int, createCombiner: V => C, @@ -71,11 +76,20 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { // Return an RDD that does each of the merges for a given partition val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) return indexes.flatMap((myId: Int) => { - val combiners = new HashMap[K, C] - for ((serverAddress, serverPort, inputId) <- splitsByUri) { + totalSplits = splitsByUri.size + hasSplitsBitVector = new BitSet (totalSplits) + combiners = new HashMap[K, C] + + while (hasSplits < totalSplits) { + // Select a random split to pull + val splitIndex = selectRandomSplit + val (serverAddress, serverPort, inputId) = + splitsByUri (splitIndex) + val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) - val shuffleClient = new ShuffleClient(serverAddress, serverPort, requestPath) + val shuffleClient = + new ShuffleClient(serverAddress, serverPort, requestPath) val readStartTime = System.currentTimeMillis logInfo ("BEGIN READ: " + requestPath) shuffleClient.start @@ -96,6 +110,11 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } inputStream.close + hasSplits += 1 + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set (splitIndex) + } + logInfo ("END READ: " + requestPath) val readTime = (System.currentTimeMillis - readStartTime) logInfo ("Reading " + requestPath + " took " + readTime + " millis.") @@ -103,6 +122,107 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { combiners }) } + + def selectRandomSplit: Int = { + var requiredSplits = new ArrayBuffer[Int] + + hasSplitsBitVector.synchronized { + for (i <- 0 until totalSplits) { + if (!hasSplitsBitVector.get(i)) { + requiredSplits += i + } + } + } + + if (requiredSplits.size > 0) { + requiredSplits(LocalFileShuffle.ranGen.nextInt (requiredSplits.size)) + } else { + -1 + } + } + + class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String) + extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + var byteArray: Array[Byte] = null + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections + } + } + + var timeOutTimer = new Timer + // TODO: Set wait timer + // TODO: If its too small, things FAIL + timeOutTimer.schedule (timeOutTask, 10000) + + logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath)) + + try { + // Connect to the source + peerSocketToSource = new Socket (hostAddress, listenPort) + oosSource = + new ObjectOutputStream (peerSocketToSource.getOutputStream) + oosSource.flush + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream (isSource) + + // Send the request + oosSource.writeObject(requestPath) + + // Receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo ("Received requestedFileLen = " + requestedFileLen) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel + + // Receive the file + if (requestedFileLen != -1) { + byteArray = new Array[Byte] (requestedFileLen) + var bytesRead = isSource.read (byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = isSource.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo ("ShuffleClient had a " + e) + } + } finally { + cleanUpConnections + } + } + + private def cleanUpConnections: Unit = { + if (oisSource != null) { + oisSource.close + } + if (oosSource != null) { + oosSource.close + } + if (peerSocketToSource != null) { + peerSocketToSource.close + } + } + } } object LocalFileShuffle extends Logging { @@ -116,6 +236,9 @@ object LocalFileShuffle extends Logging { private var serverAddress = InetAddress.getLocalHost.getHostAddress private var serverPort: Int = -1 + // Random number generator + var ranGen = new Random + private def initializeIfNeeded() = synchronized { if (!initialized) { // TODO: localDir should be created by some mechanism common to Spark @@ -203,7 +326,7 @@ object LocalFileShuffle extends Logging { serverPort = serverSocket.getLocalPort logInfo ("ShuffleServer started with " + serverSocket) - logInfo ("Local URI: " + serverAddress + ":" + serverPort) + logInfo ("Local URI: http://" + serverAddress + ":" + serverPort) try { while (true) { @@ -239,6 +362,8 @@ object LocalFileShuffle extends Logging { extends Thread with Logging { private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] os.flush + private val bos = new BufferedOutputStream (os) + bos.flush private val oos = new ObjectOutputStream (os) oos.flush private val ois = new ObjectInputStream (clientSocket.getInputStream) @@ -286,13 +411,12 @@ object LocalFileShuffle extends Logging { if(bytesRead > 0) { alreadyRead = alreadyRead + bytesRead } - } - + } bis.close // Send - os.write (byteArray, 0, byteArray.length) - os.flush + bos.write (byteArray, 0, byteArray.length) + bos.flush } else { // Close the connection } @@ -307,92 +431,11 @@ object LocalFileShuffle extends Logging { logInfo ("ShuffleServerThread is closing streams and sockets") ois.close // TODO: Following can cause "java.net.SocketException: Socket closed" - oos.close + oos.close + bos.close clientSocket.close } } } } } - -class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String) -extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - var byteArray: Array[Byte] = null - - override def run: Unit = { - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run: Unit = { - cleanUpConnections - } - } - - var timeOutTimer = new Timer - // TODO: Set wait timer - timeOutTimer.schedule (timeOutTask, 1000) - - logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath)) - - try { - // Connect to the source - peerSocketToSource = new Socket (hostAddress, listenPort) - oosSource = - new ObjectOutputStream (peerSocketToSource.getOutputStream) - oosSource.flush - var isSource = peerSocketToSource.getInputStream - oisSource = new ObjectInputStream (isSource) - - // Send the request - oosSource.writeObject(requestPath) - - // Receive the length of the requested file - var requestedFileLen = oisSource.readObject.asInstanceOf[Int] - logInfo ("Received requestedFileLen = " + requestedFileLen) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel - - // Receive the file - if (requestedFileLen != -1) { - byteArray = new Array[Byte] (requestedFileLen) - var bytesRead = isSource.read (byteArray, 0, byteArray.length) - var alreadyRead = bytesRead - - while (alreadyRead < requestedFileLen) { - bytesRead = isSource.read(byteArray, alreadyRead, - (byteArray.length - alreadyRead)) - if(bytesRead > 0) { - alreadyRead = alreadyRead + bytesRead - } - } - } else { - throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logInfo ("ShuffleClient had a " + e) - } - } finally { - cleanUpConnections - } - } - - private def cleanUpConnections: Unit = { - if (oisSource != null) { - oisSource.close - } - if (oosSource != null) { - oosSource.close - } - if (peerSocketToSource != null) { - peerSocketToSource.close - } - } -}