From 7ba34bc007ec10d12b2a871749f32232cdbc0d9c Mon Sep 17 00:00:00 2001
From: Charles Reiss <charles@eecs.berkeley.edu>
Date: Mon, 14 Jan 2013 15:24:08 -0800
Subject: [PATCH] Additional tests for MapOutputTracker.

---
 .../scala/spark/MapOutputTrackerSuite.scala   | 82 ++++++++++++++++++-
 1 file changed, 80 insertions(+), 2 deletions(-)

diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 5b4b198960..6c6f82e274 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -1,12 +1,18 @@
 package spark
 
 import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
 
 import akka.actor._
 import spark.scheduler.MapStatus
 import spark.storage.BlockManagerId
+import spark.util.AkkaUtils
 
-class MapOutputTrackerSuite extends FunSuite {
+class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
+  after {
+    System.clearProperty("spark.master.port")
+  }
+
   test("compressSize") {
     assert(MapOutputTracker.compressSize(0L) === 0)
     assert(MapOutputTracker.compressSize(1L) === 1)
@@ -71,6 +77,78 @@ class MapOutputTrackerSuite extends FunSuite {
     // The remaining reduce task might try to grab the output dispite the shuffle failure;
     // this should cause it to fail, and the scheduler will ignore the failure due to the
     // stage already being aborted.
-    intercept[Exception] { tracker.getServerStatuses(10, 1) }
+    intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+  }
+
+  test("remote fetch") {
+    val (actorSystem, boundPort) =
+      AkkaUtils.createActorSystem("test", "localhost", 0)
+    System.setProperty("spark.master.port", boundPort.toString)
+    val masterTracker = new MapOutputTracker(actorSystem, true)
+    val slaveTracker = new MapOutputTracker(actorSystem, false)
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.incrementGeneration()
+    slaveTracker.updateGeneration(masterTracker.getGeneration)
+    intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      new BlockManagerId("hostA", 1000), Array(compressedSize1000)))
+    masterTracker.incrementGeneration()
+    slaveTracker.updateGeneration(masterTracker.getGeneration)
+    assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+           Seq((new BlockManagerId("hostA", 1000), size1000)))
+
+    masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+    masterTracker.incrementGeneration()
+    slaveTracker.updateGeneration(masterTracker.getGeneration)
+    intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+  }
+
+  test("simultaneous fetch fails") {
+    val dummyActorSystem = ActorSystem("testDummy")
+    val dummyTracker = new MapOutputTracker(dummyActorSystem, true)
+    dummyTracker.registerShuffle(10, 1)
+    // val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    // val size100 = MapOutputTracker.decompressSize(compressedSize1000)
+    // dummyTracker.registerMapOutput(10, 0, new MapStatus(
+    //   new BlockManagerId("hostA", 1000), Array(compressedSize1000)))
+    val serializedMessage = dummyTracker.getSerializedLocations(10)
+
+    val (actorSystem, boundPort) =
+      AkkaUtils.createActorSystem("test", "localhost", 0)
+    System.setProperty("spark.master.port", boundPort.toString)
+    val delayResponseLock = new java.lang.Object
+    val delayResponseActor = actorSystem.actorOf(Props(new Actor {
+      override def receive = {
+        case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
+          delayResponseLock.synchronized {
+            sender ! serializedMessage
+          }
+      }
+    }), name = "MapOutputTracker")
+    val slaveTracker = new MapOutputTracker(actorSystem, false)
+    var firstFailed = false
+    var secondFailed = false
+    val firstFetch = new Thread {
+      override def run() {
+        intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+        firstFailed = true
+      }
+    }
+    val secondFetch = new Thread {
+      override def run() {
+        intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+        secondFailed = true
+      }
+    }
+    delayResponseLock.synchronized {
+      firstFetch.start
+      secondFetch.start
+    }
+    firstFetch.join
+    secondFetch.join
+    assert(firstFailed && secondFailed)
   }
 }
-- 
GitLab