diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index da07ce2c6ea4904cd6087327293eb58c9637a2ed..1b65926f5c749a33b8d5b63f921de02368e7f662 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -67,7 +67,7 @@ private[spark] object TaskLocation {
     if (hstr.equals(str)) {
       new HostTaskLocation(str)
     } else {
-      new HostTaskLocation(hstr)
+      new HDFSCacheTaskLocation(hstr)
     }
   }
 }
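
The hunk above fixes a copy-paste bug in TaskLocation.apply: when the incoming string carries the HDFS-cache tag, the prefix was stripped but the result was still wrapped in a HostTaskLocation, so HDFS-cached locality preferences silently degraded to plain host locality. A minimal sketch of the corrected dispatch, assuming the tag constant is spelled inMemoryLocationTag and the stripping uses stripPrefix (the tag value "hdfs_cache_" is confirmed by the new test below; the constant name and the stripPrefix call are not visible in this hunk):

    private[spark] object TaskLocation {
      // Tag value "hdfs_cache_" matches the new test; the constant name is assumed.
      val inMemoryLocationTag = "hdfs_cache_"

      def apply(str: String): TaskLocation = {
        // If stripping the tag changes nothing, the string is a bare host name.
        val hstr = str.stripPrefix(inMemoryLocationTag)
        if (hstr.equals(str)) {
          new HostTaskLocation(str)
        } else {
          // Fixed branch: the stripped host now maps to an HDFS-cache location.
          new HDFSCacheTaskLocation(hstr)
        }
      }
    }
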
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index f0eadf240943e8764e4f8a538461b1378081c8f1..695523cc8aa3a4d6a6396a54c29e269059b9d5e1 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -759,9 +759,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     val sched = new FakeTaskScheduler(sc,
         ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
     val taskSet = FakeTask.createTaskSet(3,
-      Seq(HostTaskLocation("host1")),
-      Seq(HostTaskLocation("host2")),
-      Seq(HDFSCacheTaskLocation("host3")))
+      Seq(TaskLocation("host1")),
+      Seq(TaskLocation("host2")),
+      Seq(TaskLocation("hdfs_cache_host3")))
     val clock = new ManualClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
     assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
@@ -776,6 +776,11 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(manager.myLocalityLevels.sameElements(Array(ANY)))
   }
 
+  test("Test TaskLocation for different host type.") {
+    assert(TaskLocation("host1") === HostTaskLocation("host1"))
+    assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1"))
+  }
+
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
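
The new test pins down the parsing contract: a bare host string yields a HostTaskLocation, while the "hdfs_cache_" prefix yields an HDFSCacheTaskLocation for the stripped host. The location types are case classes in this file, so the === comparisons check structural equality across the two construction paths. A possible companion check (not part of this patch) is that the string form round-trips through apply, assuming HDFSCacheTaskLocation.toString re-emits the tag; that toString body is not shown in these hunks:

    // Hypothetical extra test for the same suite: parsing is stable under
    // toString, assuming HDFSCacheTaskLocation.toString prepends "hdfs_cache_".
    test("TaskLocation string form round-trips through apply") {
      val loc = TaskLocation("hdfs_cache_host1")
      assert(loc === HDFSCacheTaskLocation("host1"))
      assert(TaskLocation(loc.toString) === loc)
    }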