diff --git a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala index 42dcb88ba44257faf6a2ac4ec9a4dc2c28573020..7a408ac1500ec11e1745119e716ea7b8deb192a2 100644 --- a/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala +++ b/src/scala/spark/TrackedCustomParallelLocalFileShuffle.scala @@ -380,6 +380,9 @@ extends Shuffle[K, V, C] with Logging { private var oisSource: ObjectInputStream = null private var receptionSucceeded = false + + // Make sure that multiple messages don't go to the tracker + private var alreadySentLeavingNotification = false override def run: Unit = { // Setup the timeout mechanism @@ -476,45 +479,50 @@ extends Shuffle[K, V, C] with Logging { } // Connect to the tracker and update its stats - private def sendLeavingNotification(): Unit = { - val clientSocketToTracker = - new Socket(TrackedCustomParallelLocalFileShuffle.MasterHostAddress, - TrackedCustomParallelLocalFileShuffle.MasterTrackerPort) - val oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - val oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - try { - // Send intention - oosTracker.writeObject( - TrackedCustomParallelLocalFileShuffle.ReducerLeaving) - oosTracker.flush() - - // Send reducerSplitInfo - oosTracker.writeObject(getLocalSplitInfo) + private def sendLeavingNotification(): Unit = synchronized { + if (!alreadySentLeavingNotification) { + val clientSocketToTracker = + new Socket(TrackedCustomParallelLocalFileShuffle.MasterHostAddress, + TrackedCustomParallelLocalFileShuffle.MasterTrackerPort) + val oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) oosTracker.flush() - - // Send serverSplitInfo so that tracker can update its stats - oosTracker.writeObject(splitIndex) - oosTracker.flush() - - // Receive ACK. No need to do anything with that - oisTracker.readObject.asInstanceOf[Int] - } catch { - case e: Exception => { - logInfo("sendLeavingNotification had a " + e) + val oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + try { + // Send intention + oosTracker.writeObject( + TrackedCustomParallelLocalFileShuffle.ReducerLeaving) + oosTracker.flush() + + // Send reducerSplitInfo + oosTracker.writeObject(getLocalSplitInfo) + oosTracker.flush() + + // Send serverSplitInfo so that tracker can update its stats + oosTracker.writeObject(splitIndex) + oosTracker.flush() + + // Receive ACK. No need to do anything with that + oisTracker.readObject.asInstanceOf[Int] + + // Now update sentLeavingNotifacation + alreadySentLeavingNotification = true + } catch { + case e: Exception => { + logInfo("sendLeavingNotification had a " + e) + } + } finally { + oisTracker.close() + oosTracker.close() + clientSocketToTracker.close() } - } finally { - oisTracker.close() - oosTracker.close() - clientSocketToTracker.close() - } + } } private def cleanUp(): Unit = { - // Update tracker stats first. + // Update tracker stats first sendLeavingNotification() // Clean up the connections to the mapper