diff --git a/conf/java-opts b/conf/java-opts
index b851798d093b9f715fafad5240af415ef9d5015f..4649061fcb4672f1a66459ea87e88f140000198b 100644
--- a/conf/java-opts
+++ b/conf/java-opts
@@ -1 +1 @@
--Dspark.shuffle.class=spark.HttpParallelLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000
+-Dspark.shuffle.class=spark.CustomBlockedLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000
diff --git a/src/scala/spark/CustomBlockedLocalFileShuffle.scala b/src/scala/spark/CustomBlockedLocalFileShuffle.scala
index 319a9d360cf36e75e2539fdf8fd2b95beba65796..75f8c0bffe69040c45bcd3d1951538fdb3953534 100644
--- a/src/scala/spark/CustomBlockedLocalFileShuffle.scala
+++ b/src/scala/spark/CustomBlockedLocalFileShuffle.scala
@@ -4,7 +4,7 @@
 import java.io._
 import java.net._
 import java.util.{BitSet, Random, Timer, TimerTask, UUID}
 import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -32,6 +32,7 @@ extends Shuffle[K, V, C] with Logging {
   @transient var hasSplitsBitVector: BitSet = null
   @transient var splitsInRequestBitVector: BitSet = null
 
+  @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
   @transient var combiners: HashMap[K,C] = null
 
   override def compute(input: RDD[(K, V)],
@@ -135,8 +136,15 @@ extends Shuffle[K, V, C] with Logging {
       hasSplitsBitVector = new BitSet(totalSplits)
       splitsInRequestBitVector = new BitSet(totalSplits)
 
+      receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
       combiners = new HashMap[K, C]
 
+      // Start consumer
+      var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+      shuffleConsumer.setDaemon(true)
+      shuffleConsumer.start()
+      logInfo("ShuffleConsumer started...")
+
       var threadPool = CustomBlockedLocalFileShuffle.newDaemonFixedThreadPool(
         CustomBlockedLocalFileShuffle.MaxRxConnections)
 
@@ -192,6 +200,59 @@ extends Shuffle[K, V, C] with Logging {
     }
   }
 
+  class ShuffleConsumer(mergeCombiners: (C, C) => C)
+  extends Thread with Logging {
+    override def run: Unit = {
+      // Run until all splits are here
+      while (hasSplits < totalSplits) {
+        var splitIndex = -1
+        var recvByteArray: Array[Byte] = null
+
+        try {
+          var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+          splitIndex = tempPair._1
+          recvByteArray = tempPair._2
+        } catch {
+          case e: Exception => {
+            logInfo("Exception during taking data from receivedData")
+          }
+        }
+
+        val inputStream =
+          new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+        try {
+          while (true) {
+            val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => { }
+        }
+        inputStream.close()
+
+        // Consumption completed. Update stats.
+        hasBlocksInSplit(splitIndex) = hasBlocksInSplit(splitIndex) + 1
+
+        // Split has been received only if all the blocks have been received
+        if (hasBlocksInSplit(splitIndex) == totalBlocksInSplit(splitIndex)) {
+          hasSplitsBitVector.synchronized {
+            hasSplitsBitVector.set(splitIndex)
+          }
+          hasSplits += 1
+        }
+
+        // We have received splitIndex
+        splitsInRequestBitVector.synchronized {
+          splitsInRequestBitVector.set(splitIndex, false)
+        }
+      }
+    }
+  }
+
   class ShuffleClient(serverUri: String, shuffleId: Int,
     inputId: Int, myId: Int, splitIndex: Int,
     mergeCombiners: (C, C) => C)
@@ -201,60 +262,66 @@ extends Shuffle[K, V, C] with Logging {
     override def run: Unit = {
       try {
         // TODO: Everything will break if BLOCKNUM is not correctly received
-        // First get the BLOCKNUM file if totalBlocksInSplit(inputId) is unknown
-        if (totalBlocksInSplit(inputId) == -1) {
+        // First get BLOCKNUM file if totalBlocksInSplit(splitIndex) is unknown
+        if (totalBlocksInSplit(splitIndex) == -1) {
           val url = "%s/shuffle/%d/%d/BLOCKNUM-%d".format(serverUri, shuffleId,
             inputId, myId)
           val inputStream = new ObjectInputStream(new URL(url).openStream())
-          totalBlocksInSplit(inputId) =
+          totalBlocksInSplit(splitIndex) =
             inputStream.readObject().asInstanceOf[Int]
           inputStream.close()
         }
 
-        val url =
+        // Open connection
+        val urlString =
           "%s/shuffle/%d/%d/%d-%d".format(serverUri, shuffleId, inputId,
-            myId, hasBlocksInSplit(inputId))
+            myId, hasBlocksInSplit(splitIndex))
+        val url = new URL(urlString)
+        val httpConnection =
+          url.openConnection().asInstanceOf[HttpURLConnection]
+
+        // Connect to the server
+        httpConnection.connect()
+
+        // Receive file length
+        var requestedFileLen = httpConnection.getContentLength
 
         val readStartTime = System.currentTimeMillis
         logInfo("BEGIN READ: " + url)
-        val inputStream = new ObjectInputStream(new URL(url).openStream())
-        try {
-          while (true) {
-            val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
-            combiners.synchronized {
-              combiners(k) = combiners.get(k) match {
-                case Some(oldC) => mergeCombiners(oldC, c)
-                case None => c
-              }
-            }
+        // Receive data in an Array[Byte]
+        var recvByteArray = new Array[Byte](requestedFileLen)
+        var alreadyRead = 0
+        var bytesRead = 0
+
+        val isSource = httpConnection.getInputStream()
+        while (alreadyRead != requestedFileLen) {
+          bytesRead = isSource.read(recvByteArray, alreadyRead,
+            requestedFileLen - alreadyRead)
+          if (bytesRead > 0) {
+            alreadyRead = alreadyRead + bytesRead
           }
+        }
+
+        // Disconnect
+        httpConnection.disconnect()
+
+        // Make it available to the consumer
+        try {
+          receivedData.put((splitIndex, recvByteArray))
         } catch {
-          case e: EOFException => {}
+          case e: Exception => {
+            logInfo("Exception during putting data into receivedData")
+          }
         }
-        inputStream.close()
+        // NOTE: Update of bitVectors are now done by the consumer
+
+        receptionSucceeded = true
 
         logInfo("END READ: " + url)
         val readTime = System.currentTimeMillis - readStartTime
         logInfo("Reading " + url + " took " + readTime + " millis.")
-
-        // Reception completed. Update stats.
-        hasBlocksInSplit(inputId) = hasBlocksInSplit(inputId) + 1
-
-        // Split has been received only if all the blocks have been received
-        if (hasBlocksInSplit(inputId) == totalBlocksInSplit(inputId)) {
-          hasSplitsBitVector.synchronized {
-            hasSplitsBitVector.set(splitIndex)
-          }
-          hasSplits += 1
-        }
-
-        // We have received splitIndex
-        splitsInRequestBitVector.synchronized {
-          splitsInRequestBitVector.set(splitIndex, false)
-        }
-
-        receptionSucceeded = true
       } catch {
         // EOFException is expected to happen because sender can break
         // connection due to timeout