diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala index 6e69f48868a6e4a6edb4f7cb72042269285d075a..6ddd6612883e8d445891df079485cae9766ab0e7 100644 --- a/src/scala/spark/LocalFileShuffle.scala +++ b/src/scala/spark/LocalFileShuffle.scala @@ -18,8 +18,10 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { @transient var totalSplits = 0 @transient var hasSplits = 0 @transient var hasSplitsBitVector: BitSet = null - @transient var combiners: HashMap[K,C] = null + @transient var splitsInRequestBitVector: BitSet = null + @transient var combiners: HashMap[K,C] = null + override def compute(input: RDD[(K, V)], numOutputSplits: Int, createCombiner: V => C, @@ -77,33 +79,44 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) return indexes.flatMap((myId: Int) => { totalSplits = splitsByUri.size + hasSplits = 0 hasSplitsBitVector = new BitSet (totalSplits) + splitsInRequestBitVector = new BitSet (totalSplits) combiners = new HashMap[K, C] + // TODO: Fix config param + var threadPool = LocalFileShuffle.newDaemonFixedThreadPool (2) + while (hasSplits < totalSplits) { - // Select a random split to pull - val splitIndex = selectRandomSplit - val (serverAddress, serverPort, inputId) = - splitsByUri (splitIndex) + // TODO: + var numThreadsToCreate = + Math.min (totalSplits, 2) - threadPool.getActiveCount + + while (hasSplits < totalSplits && numThreadsToCreate > 0) { + // Select a random split to pull + val splitIndex = selectRandomSplit + + if (splitIndex != -1) { + val (serverAddress, serverPort, inputId) = splitsByUri (splitIndex) + val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) - val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) - - val shuffleClient = new ShuffleClient(serverAddress, serverPort, - requestPath, mergeCombiners) - val readStartTime = System.currentTimeMillis - logInfo ("BEGIN READ: " + requestPath) - shuffleClient.start - shuffleClient.join - - hasSplits += 1 - hasSplitsBitVector.synchronized { - hasSplitsBitVector.set (splitIndex) + threadPool.execute (new ShuffleClient (splitIndex, serverAddress, + serverPort, requestPath, mergeCombiners)) + + // splitIndex is in transit. Will be unset in the ShuffleClient + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set (splitIndex) + } + } + + numThreadsToCreate = numThreadsToCreate - 1 } - logInfo ("END READ: " + requestPath) - val readTime = (System.currentTimeMillis - readStartTime) - logInfo ("Reading " + requestPath + " took " + readTime + " millis.") + // TODO: + Thread.sleep (1000) } + + threadPool.shutdown combiners }) } @@ -111,9 +124,9 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { def selectRandomSplit: Int = { var requiredSplits = new ArrayBuffer[Int] - hasSplitsBitVector.synchronized { + synchronized { for (i <- 0 until totalSplits) { - if (!hasSplitsBitVector.get(i)) { + if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) { requiredSplits += i } } @@ -126,7 +139,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } } - class ShuffleClient (hostAddress: String, listenPort: Int, + class ShuffleClient (splitIndex: Int, hostAddress: String, listenPort: Int, requestPath: String, mergeCombiners: (C, C) => C) extends Thread with Logging { private var peerSocketToSource: Socket = null @@ -134,8 +147,13 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { private var oisSource: ObjectInputStream = null private var byteArray: Array[Byte] = null + + private var receptionSucceeded = false override def run: Unit = { + val readStartTime = System.currentTimeMillis + logInfo ("BEGIN READ: " + requestPath) + // Setup the timeout mechanism var timeOutTask = new TimerTask { override def run: Unit = { @@ -146,7 +164,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { var timeOutTimer = new Timer // TODO: Set wait timer // TODO: If its too small, things FAIL - timeOutTimer.schedule (timeOutTask, 100000) + timeOutTimer.schedule (timeOutTask, 5000) logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath)) @@ -182,26 +200,45 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { alreadyRead = alreadyRead + bytesRead } } + + logInfo ("Finished reading " + requestPath) // Now add this to combiners - combiners.synchronized { - val inputStream = new ObjectInputStream ( - new ByteArrayInputStream(byteArray)) - try { - while (true) { - val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + val inputStream = new ObjectInputStream ( + new ByteArrayInputStream(byteArray)) + try{ + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners.synchronized { combiners(k) = combiners.get(k) match { case Some(oldC) => mergeCombiners(oldC, c) case None => c } } - } catch { - case e: Exception => { - logInfo ("Merging to combiners had a " + e) - } } - inputStream.close - } + } catch { + case e: EOFException => { } + } + inputStream.close + + logInfo ("Finished combining " + requestPath) + + // Reception completed. Update stats. + hasSplitsBitVector.synchronized { + hasSplitsBitVector.set (splitIndex) + } + hasSplits += 1 + + // We have received splitIndex + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set (splitIndex, false) + } + + receptionSucceeded = true + + logInfo ("END READ: " + requestPath) + val readTime = (System.currentTimeMillis - readStartTime) + logInfo ("Reading " + requestPath + " took " + readTime + " millis.") } else { throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) } @@ -213,6 +250,12 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { logInfo ("ShuffleClient had a " + e) } } finally { + // If reception failed, unset for future retry + if (!receptionSucceeded) { + splitsInRequestBitVector.synchronized { + splitsInRequestBitVector.set (splitIndex, false) + } + } cleanUpConnections } }