diff --git a/conf/java-opts b/conf/java-opts
index af2b51124a9c24f07ceb3fd286b73addf7637bbb..3de92cf38293ce8a387b8781a07cc62e9b077e54 100644
--- a/conf/java-opts
+++ b/conf/java-opts
@@ -1 +1 @@
--Dspark.shuffle.class=spark.CustomParallelLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=1000 -Dspark.parallelLocalFileShuffle.maxKnockInterval=5000
+-Dspark.shuffle.class=spark.HttpBlockedLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=1000 -Dspark.parallelLocalFileShuffle.maxKnockInterval=5000
diff --git a/src/scala/spark/CustomParallelLocalFileShuffle.scala b/src/scala/spark/CustomParallelLocalFileShuffle.scala
index 2fa32473835b5bf0953ba63d9b27ad6ed1c7af84..a6d8845e5b2ffd9e7a703d9149cc527fcfe223cd 100644
--- a/src/scala/spark/CustomParallelLocalFileShuffle.scala
+++ b/src/scala/spark/CustomParallelLocalFileShuffle.scala
@@ -154,7 +154,7 @@ extends Shuffle[K, V, C] with Logging {
     }
   }
 
-  class ShuffleConsumer(mergeCombiners: (C, C) => C)
+  class ShuffleConsumer(mergeCombiners: (C, C) => C)
  extends Thread with Logging {
    override def run: Unit = {
      // Run until all splits are here
@@ -272,7 +272,7 @@ extends Shuffle[K, V, C] with Logging {
        }
      }

-      // NOTE: Update of bitVectors are now done by the consumer.
+      // NOTE: Update of bitVectors are now done by the consumer

      receptionSucceeded = true
diff --git a/src/scala/spark/HttpBlockedLocalFileShuffle.scala b/src/scala/spark/HttpBlockedLocalFileShuffle.scala
index cc927b6e9e203d8ea42feb5c48e3750ff56b76b5..ace82c1c658b9e3c4b8eae9ed3ab195a2272101d 100644
--- a/src/scala/spark/HttpBlockedLocalFileShuffle.scala
+++ b/src/scala/spark/HttpBlockedLocalFileShuffle.scala
@@ -4,7 +4,7 @@ import java.io._
 import java.net._
 import java.util.{BitSet, Random, Timer, TimerTask, UUID}
 import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 
@@ -33,6 +33,7 @@ extends Shuffle[K, V, C] with Logging {
   @transient var hasSplitsBitVector: BitSet = null
   @transient var splitsInRequestBitVector: BitSet = null
 
+  @transient var receivedData: LinkedBlockingQueue[(Int, Int, Array[Byte])] = null
   @transient var combiners: HashMap[K,C] = null
 
   override def compute(input: RDD[(K, V)],
@@ -129,8 +130,15 @@ extends Shuffle[K, V, C] with Logging {
       hasSplitsBitVector = new BitSet(totalSplits)
       splitsInRequestBitVector = new BitSet(totalSplits)
 
+      receivedData = new LinkedBlockingQueue[(Int, Int, Array[Byte])]
       combiners = new HashMap[K, C]
 
+      // Start consumer
+      var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+      shuffleConsumer.setDaemon(true)
+      shuffleConsumer.start()
+      logInfo("ShuffleConsumer started...")
+
       var threadPool = HttpBlockedLocalFileShuffle.newDaemonFixedThreadPool(
         HttpBlockedLocalFileShuffle.MaxRxConnections)
 
@@ -186,6 +194,62 @@ extends Shuffle[K, V, C] with Logging {
     }
   }
 
+  class ShuffleConsumer(mergeCombiners: (C, C) => C)
+  extends Thread with Logging {
+    override def run: Unit = {
+      // Run until all splits are here
+      while (hasSplits < totalSplits) {
+        var inputId = -1
+        var splitIndex = -1
+        var recvByteArray: Array[Byte] = null
+
+        try {
+          var tempTuple =
+            receivedData.take().asInstanceOf[(Int, Int, Array[Byte])]
+          inputId = tempTuple._1
+          splitIndex = tempTuple._2
+          recvByteArray = tempTuple._3
+        } catch {
+          case e: Exception => {
+            logInfo("Exception during taking data from receivedData")
+          }
+        }
+
+        val inputStream =
+          new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+        try {
+          while (true) {
+            val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => { }
+        }
+        inputStream.close()
+
+        // Consumption completed. Update stats.
+        hasBlocksInSplit(inputId) = hasBlocksInSplit(inputId) + 1
+
+        // Split has been received only if all the blocks have been received
+        if (hasBlocksInSplit(inputId) == totalBlocksInSplit(inputId)) {
+          hasSplitsBitVector.synchronized {
+            hasSplitsBitVector.set(splitIndex)
+          }
+          hasSplits += 1
+        }
+
+        // We have received splitIndex
+        splitsInRequestBitVector.synchronized {
+          splitsInRequestBitVector.set(splitIndex, false)
+        }
+      }
+    }
+  }
+
   class ShuffleClient(serverUri: String, shuffleId: Int, inputId: Int,
     myId: Int, splitIndex: Int,
     mergeCombiners: (C, C) => C)
@@ -212,7 +276,8 @@ extends Shuffle[K, V, C] with Logging {
         totalBlocksInSplit(inputId) = blocksInSplit(inputId).size
         inputStream.close()
       }
-
+
+      // Open connection
       val urlString =
         "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId)
       val url = new URL(urlString)
@@ -236,47 +301,40 @@ extends Shuffle[K, V, C] with Logging {
       val readStartTime = System.currentTimeMillis
       logInfo("BEGIN READ: " + urStringWithRange)
 
-      // Receive the block
-      val inputStream = new ObjectInputStream(httpConnection.getInputStream())
-      try {
-        while (true) {
-          val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
-          combiners.synchronized {
-            combiners(k) = combiners.get(k) match {
-              case Some(oldC) => mergeCombiners(oldC, c)
-              case None => c
-            }
-          }
-        }
-      } catch {
-        case e: EOFException => {}
-      }
-      inputStream.close()
-
-      logInfo("END READ: " + urStringWithRange)
-      val readTime = System.currentTimeMillis - readStartTime
-      logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.")
-
+      // Receive data in an Array[Byte]
+      val requestedFileLen: Int = (blockEndsAt - blockStartsAt).toInt + 1
+      var recvByteArray = new Array[Byte](requestedFileLen)
+      var alreadyRead = 0
+      var bytesRead = 0
+
+      val isSource = httpConnection.getInputStream()
+      while (alreadyRead != requestedFileLen) {
+        bytesRead = isSource.read(recvByteArray, alreadyRead,
+          requestedFileLen - alreadyRead)
+        if (bytesRead > 0) {
+          alreadyRead = alreadyRead + bytesRead
+        }
+      }
+
+      // Disconnect
       httpConnection.disconnect()
 
-      // Reception completed. Update stats.
-      hasBlocksInSplit(inputId) = hasBlocksInSplit(inputId) + 1
-
-      // Split has been received only if all the blocks have been received
-      if (hasBlocksInSplit(inputId) == totalBlocksInSplit(inputId)) {
-        hasSplitsBitVector.synchronized {
-          hasSplitsBitVector.set(splitIndex)
+      // Make it available to the consumer
+      try {
+        receivedData.put((inputId, splitIndex, recvByteArray))
+      } catch {
+        case e: Exception => {
+          logInfo("Exception during putting data into receivedData")
         }
-        hasSplits += 1
-      }
-
-      // We have received splitIndex
-      splitsInRequestBitVector.synchronized {
-        splitsInRequestBitVector.set(splitIndex, false)
       }
+
+      // NOTE: Update of bitVectors are now done by the consumer
 
       receptionSucceeded = true
+
+      logInfo("END READ: " + urStringWithRange)
+      val readTime = System.currentTimeMillis - readStartTime
+      logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.")
     } catch {
       // EOFException is expected to happen because sender can break
       // connection due to timeout
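
The hunks above restructure reception into a producer/consumer pair: each ShuffleClient now downloads a block into an Array[Byte] and puts an (inputId, splitIndex, bytes) tuple on the shared LinkedBlockingQueue, while a single daemon ShuffleConsumer takes tuples off the queue, deserializes them, and merges combiners. Since only the consumer thread touches combiners, the old combiners.synchronized block is no longer needed. Below is a minimal standalone sketch of that handoff; the names Block, produce, and QueueHandoffSketch are illustrative and not from the Spark source.

import java.util.concurrent.LinkedBlockingQueue

object QueueHandoffSketch {
  // What each receiver thread hands off: raw bytes tagged with their origin
  case class Block(inputId: Int, splitIndex: Int, bytes: Array[Byte])

  val receivedData = new LinkedBlockingQueue[Block]

  // Producer side: a ShuffleClient-like thread enqueues after its read loop
  def produce(inputId: Int, splitIndex: Int, bytes: Array[Byte]): Unit =
    receivedData.put(Block(inputId, splitIndex, bytes))

  // Consumer side: one daemon thread owns all deserialization and merging,
  // so no lock on the merged state is needed (this mirrors the patch
  // dropping combiners.synchronized)
  val consumer = new Thread {
    override def run(): Unit = {
      while (true) {
        val block = receivedData.take() // blocks until a producer puts a tuple
        println("consumed " + block.bytes.length + " bytes of split " + block.splitIndex)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    consumer.setDaemon(true)
    consumer.start()
    produce(0, 0, Array[Byte](1, 2, 3))
    Thread.sleep(100) // give the daemon time to drain before the JVM exits
  }
}

One consequence of this design, visible in the diff, is that network reads and CPU-bound deserialization overlap: a receiver thread can start its next ranged HTTP read while the consumer is still merging the previous block.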