diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index da07ce2c6ea4904cd6087327293eb58c9637a2ed..1b65926f5c749a33b8d5b63f921de02368e7f662 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -67,7 +67,7 @@ private[spark] object TaskLocation { if (hstr.equals(str)) { new HostTaskLocation(str) } else { - new HostTaskLocation(hstr) + new HDFSCacheTaskLocation(hstr) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index f0eadf240943e8764e4f8a538461b1378081c8f1..695523cc8aa3a4d6a6396a54c29e269059b9d5e1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -759,9 +759,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) val taskSet = FakeTask.createTaskSet(3, - Seq(HostTaskLocation("host1")), - Seq(HostTaskLocation("host2")), - Seq(HDFSCacheTaskLocation("host3"))) + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("hdfs_cache_host3"))) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) @@ -776,6 +776,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.myLocalityLevels.sameElements(Array(ANY))) } + test("Test TaskLocation for different host type.") { + assert(TaskLocation("host1") === HostTaskLocation("host1")) + assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1")) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)