diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala index 79a3ab74dd9098f99fd1857b65482a338a904d5a..6e69f48868a6e4a6edb4f7cb72042269285d075a 100644 --- a/src/scala/spark/LocalFileShuffle.scala +++ b/src/scala/spark/LocalFileShuffle.scala @@ -88,28 +88,13 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) - val shuffleClient = - new ShuffleClient(serverAddress, serverPort, requestPath) + val shuffleClient = new ShuffleClient(serverAddress, serverPort, + requestPath, mergeCombiners) val readStartTime = System.currentTimeMillis logInfo ("BEGIN READ: " + requestPath) shuffleClient.start shuffleClient.join - val inputStream = new ObjectInputStream ( - new ByteArrayInputStream(shuffleClient.byteArray)) - try { - while (true) { - val (k, c) = inputStream.readObject().asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } - } - } catch { - case e: EOFException => {} - } - inputStream.close - hasSplits += 1 hasSplitsBitVector.synchronized { hasSplitsBitVector.set (splitIndex) @@ -141,13 +126,14 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { } } - class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String) + class ShuffleClient (hostAddress: String, listenPort: Int, + requestPath: String, mergeCombiners: (C, C) => C) extends Thread with Logging { private var peerSocketToSource: Socket = null private var oosSource: ObjectOutputStream = null private var oisSource: ObjectInputStream = null - var byteArray: Array[Byte] = null + private var byteArray: Array[Byte] = null override def run: Unit = { // Setup the timeout mechanism @@ -160,7 +146,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { var timeOutTimer = new Timer // TODO: Set wait timer // TODO: If its too small, things FAIL - timeOutTimer.schedule (timeOutTask, 10000) + timeOutTimer.schedule (timeOutTask, 100000) logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath)) @@ -195,7 +181,27 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { if(bytesRead > 0) { alreadyRead = alreadyRead + bytesRead } - } + } + + // Now add this to combiners + combiners.synchronized { + val inputStream = new ObjectInputStream ( + new ByteArrayInputStream(byteArray)) + try { + while (true) { + val (k, c) = inputStream.readObject.asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c + } + } + } catch { + case e: Exception => { + logInfo ("Merging to combiners had a " + e) + } + } + inputStream.close + } } else { throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) }