diff --git a/src/scala/spark/BlockedLocalFileShuffle.scala b/src/scala/spark/BlockedLocalFileShuffle.scala index bd88a263b9925200ed8dd39442c4905afac688f4..592ec5f6effb27956deec2be421782cdd88c7c5d 100644 --- a/src/scala/spark/BlockedLocalFileShuffle.scala +++ b/src/scala/spark/BlockedLocalFileShuffle.scala @@ -26,6 +26,7 @@ class BlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { @transient var totalSplits = 0 @transient var hasSplits = 0 + @transient var blocksInSplit: Array[ArrayBuffer[Long]] = null @transient var totalBlocksInSplit: Array[Int] = null @transient var hasBlocksInSplit: Array[Int] = null @@ -67,55 +68,45 @@ class BlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } for (i <- 0 until numOutputSplits) { - var blockNum = 0 - var isDirty = false - var file: File = null - var out: ObjectOutputStream = null - - var writeStartTime: Long = 0 + // Open the INDEX file + var indexFile: File = + BlockedLocalFileShuffle.getBlockIndexOutputFile(shuffleId, myIndex, i) + var indexOut = new ObjectOutputStream(new FileOutputStream(indexFile)) + var indexDirty: Boolean = true + var alreadyWritten: Long = 0 + + // Open the actual file + var file: File = + BlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i) + val out = new ObjectOutputStream(new FileOutputStream(file)) + val writeStartTime = System.currentTimeMillis + logInfo("BEGIN WRITE: " + file) + buckets(i).foreach(pair => { - // Open a new file if necessary - if (!isDirty) { - file = BlockedLocalFileShuffle.getOutputFile(shuffleId, myIndex, i, - blockNum) - writeStartTime = System.currentTimeMillis - logInfo("BEGIN WRITE: " + file) - - out = new ObjectOutputStream(new FileOutputStream(file)) - } - out.writeObject(pair) out.flush() - isDirty = true + indexDirty = true - // Close the old file if has crossed the blockSize limit - if (file.length > BlockedLocalFileShuffle.BlockSize) { - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 - isDirty = false + // Update the INDEX file if more than blockSize limit has been written + if (file.length - alreadyWritten > BlockedLocalFileShuffle.BlockSize) { + indexOut.writeObject(file.length) + indexDirty = false + alreadyWritten = file.length } }) - if (isDirty) { - out.close() - logInfo("END WRITE: " + file) - val writeTime = System.currentTimeMillis - writeStartTime - logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") - - blockNum = blockNum + 1 + // Write down the last range if it was not written + if (indexDirty) { + indexOut.writeObject(file.length) } - // Write the BLOCKNUM file - file = - BlockedLocalFileShuffle.getBlockNumOutputFile(shuffleId, myIndex, i) - out = new ObjectOutputStream(new FileOutputStream(file)) - out.writeObject(blockNum) out.close() + indexOut.close() + + logInfo("END WRITE: " + file) + val writeTime = (System.currentTimeMillis - writeStartTime) + logInfo("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") } (myIndex, BlockedLocalFileShuffle.serverUri) @@ -129,6 +120,7 @@ class BlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { totalSplits = outputLocs.size hasSplits = 0 + blocksInSplit = Array.tabulate(totalSplits)(_ => new ArrayBuffer[Long]) totalBlocksInSplit = Array.tabulate(totalSplits)(_ => -1) hasBlocksInSplit = Array.tabulate(totalSplits)(_ => 0) @@ -198,25 +190,50 @@ class BlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { override def run: Unit = { try { - // TODO: Everything will break if BLOCKNUM is not correctly received - // First get the BLOCKNUM file if totalBlocksInSplit(inputId) is unknown + // First get the INDEX file if totalBlocksInSplit(inputId) is unknown if (totalBlocksInSplit(inputId) == -1) { - val url = "%s/shuffle/%d/%d/BLOCKNUM-%d".format(serverUri, shuffleId, + val url = "%s/shuffle/%d/%d/INDEX-%d".format(serverUri, shuffleId, inputId, myId) val inputStream = new ObjectInputStream(new URL(url).openStream()) - totalBlocksInSplit(inputId) = - inputStream.readObject().asInstanceOf[Int] + + try { + while (true) { + blocksInSplit(inputId) += + inputStream.readObject().asInstanceOf[Long] + } + } catch { + case e: EOFException => {} + } + + totalBlocksInSplit(inputId) = blocksInSplit(inputId).size inputStream.close() } - val url = - "%s/shuffle/%d/%d/%d-%d".format(serverUri, shuffleId, inputId, - myId, hasBlocksInSplit(inputId)) + val urlString = + "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, myId) + val url = new URL(urlString) + val httpConnection = + url.openConnection().asInstanceOf[HttpURLConnection] + // Set the range to download + val blockStartsAt = hasBlocksInSplit(inputId) match { + case 0 => 0 + case _ => blocksInSplit(inputId)(hasBlocksInSplit(inputId) - 1) + 1 + } + val blockEndsAt = blocksInSplit(inputId)(hasBlocksInSplit(inputId)) + httpConnection.setRequestProperty("Range", + "bytes=" + blockStartsAt + "-" + blockEndsAt) + + // Connect to the server + httpConnection.connect() + + val urStringWithRange = + urlString + "[%d:%d]".format(blockStartsAt, blockEndsAt) val readStartTime = System.currentTimeMillis - logInfo("BEGIN READ: " + url) + logInfo("BEGIN READ: " + urStringWithRange) - val inputStream = new ObjectInputStream(new URL(url).openStream()) + // Receive the block + val inputStream = new ObjectInputStream(httpConnection.getInputStream()) try { while (true) { val (k, c) = inputStream.readObject().asInstanceOf[(K, C)] @@ -232,9 +249,12 @@ class BlockedLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } inputStream.close() - logInfo("END READ: " + url) + logInfo("END READ: " + urStringWithRange) val readTime = System.currentTimeMillis - readStartTime - logInfo("Reading " + url + " took " + readTime + " millis.") + logInfo("Reading " + urStringWithRange + " took " + readTime + " millis.") + + // Disconnect + httpConnection.disconnect() // Reception completed. Update stats. hasBlocksInSplit(inputId) = hasBlocksInSplit(inputId) + 1 @@ -366,21 +386,20 @@ object BlockedLocalFileShuffle extends Logging { def MaxConnections = MaxConnections_ - def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int, - blockId: Int): File = { + def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { initializeIfNeeded() val dir = new File(shuffleDir, shuffleId + "/" + inputId) dir.mkdirs() - val file = new File(dir, "%d-%d".format(outputId, blockId)) + val file = new File(dir, "" + outputId) return file } - def getBlockNumOutputFile(shuffleId: Long, inputId: Int, + def getBlockIndexOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { initializeIfNeeded() val dir = new File(shuffleDir, shuffleId + "/" + inputId) dir.mkdirs() - val file = new File(dir, "BLOCKNUM-" + outputId) + val file = new File(dir, "INDEX-" + outputId) return file }