diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
index db638d84c0a1c19f3b273ce4417171bebb062bbe..872fd354273f6c69a0c85cd27f14a7b1445d146d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala
@@ -23,7 +23,6 @@ import scala.collection.JavaConverters._
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.yarn.api.records.{ContainerId, Resource}
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
-import org.apache.hadoop.yarn.util.RackResolver
 
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.config._
@@ -83,7 +82,8 @@ private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], rack
 private[yarn] class LocalityPreferredContainerPlacementStrategy(
     val sparkConf: SparkConf,
     val yarnConf: Configuration,
-    val resource: Resource) {
+    val resource: Resource,
+    resolver: SparkRackResolver) {
 
   /**
    * Calculate each container's node locality and rack locality
@@ -139,7 +139,7 @@ private[yarn] class LocalityPreferredContainerPlacementStrategy(
         // still be allocated with new container request.
         val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray
         val racks = hosts.map { h =>
-          RackResolver.resolve(yarnConf, h).getNetworkLocation
+          resolver.resolve(yarnConf, h)
         }.toSet
         containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray)
 
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c711d088f21167551081b901364e7ae4a0c90b61
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.util.RackResolver
+import org.apache.log4j.{Level, Logger}
+
+/**
+ * Wrapper around YARN's [[RackResolver]]. This allows Spark tests to easily override the
+ * default behavior, since YARN's class self-initializes the first time it's called, and
+ * future calls all use the initial configuration.
+ */
+private[yarn] class SparkRackResolver {
+
+  // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
+  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
+    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
+  }
+
+  def resolve(conf: Configuration, hostName: String): String = {
+    RackResolver.resolve(conf, hostName).getNetworkLocation()
+  }
+
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 0b66d1cf08eac468c68a23f52fb5b7d2d7b3b01a..639e564d4684b751ca88b62add262e66e97e5b0f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.client.api.AMRMClient
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.util.RackResolver
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkException}
@@ -65,16 +64,12 @@ private[yarn] class YarnAllocator(
     amClient: AMRMClient[ContainerRequest],
     appAttemptId: ApplicationAttemptId,
     securityMgr: SecurityManager,
-    localResources: Map[String, LocalResource])
+    localResources: Map[String, LocalResource],
+    resolver: SparkRackResolver)
   extends Logging {
 
   import YarnAllocator._
 
-  // RackResolver logs an INFO message whenever it resolves a rack, which is way too often.
-  if (Logger.getLogger(classOf[RackResolver]).getLevel == null) {
-    Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN)
-  }
-
   // Visible for testing.
   val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]
   val allocatedContainerToHostMap = new HashMap[ContainerId, String]
@@ -171,7 +166,7 @@ private[yarn] class YarnAllocator(
 
   // A container placement strategy based on pending tasks' locality preference
   private[yarn] val containerPlacementStrategy =
-    new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource)
+    new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource, resolver)
 
   /**
    * Use a different clock for YarnAllocator. This is mainly used for testing.
@@ -422,7 +417,7 @@ private[yarn] class YarnAllocator(
     // Match remaining by rack
     val remainingAfterRackMatches = new ArrayBuffer[Container]
     for (allocatedContainer <- remainingAfterHostMatches) {
-      val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation
+      val rack = resolver.resolve(conf, allocatedContainer.getNodeId.getHost)
       matchContainerToRequest(allocatedContainer, rack, containersToUse,
         remainingAfterRackMatches)
     }
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 53df11eb6602134a3684f2974a5808fea870a149..9e14d63be55e6fb93f729aec8c9f6a4b520454d1 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -75,7 +75,7 @@ private[spark] class YarnRMClient extends Logging {
       registered = true
     }
     new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr,
-      localResources)
+      localResources, new SparkRackResolver())
   }
 
   /**
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala
index fb80ff9f3132264fa2bf6903e347041522308967..b7f25656e49ac702640233147f10a59048c77d55 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/LocalityPlacementStrategySuite.scala
@@ -17,10 +17,9 @@
 
 package org.apache.spark.deploy.yarn
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.{HashMap, HashSet, Set}
 
-import org.apache.hadoop.fs.CommonConfigurationKeysPublic
-import org.apache.hadoop.net.DNSToSwitchMapping
 import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.mockito.Mockito._
@@ -51,9 +50,6 @@ class LocalityPlacementStrategySuite extends SparkFunSuite {
 
   private def runTest(): Unit = {
     val yarnConf = new YarnConfiguration()
-    yarnConf.setClass(
-      CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
-      classOf[MockResolver], classOf[DNSToSwitchMapping])
 
     // The numbers below have been chosen to balance being large enough to replicate the
     // original issue while not taking too long to run when the issue is fixed. The main
@@ -62,7 +58,7 @@ class LocalityPlacementStrategySuite extends SparkFunSuite {
 
     val resource = Resource.newInstance(8 * 1024, 4)
     val strategy = new LocalityPreferredContainerPlacementStrategy(new SparkConf(),
-      yarnConf, resource)
+      yarnConf, resource, new MockResolver())
 
     val totalTasks = 32 * 1024
     val totalContainers = totalTasks / 16
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 994dc75d34c304c9e31c54f89f26343a9c2c2026..1b3f438be4d88f230019245a6786a11938c89b41 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -17,10 +17,7 @@
 
 package org.apache.spark.deploy.yarn
 
-import java.util.{Arrays, List => JList}
-
-import org.apache.hadoop.fs.CommonConfigurationKeysPublic
-import org.apache.hadoop.net.DNSToSwitchMapping
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.client.api.AMRMClient
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
@@ -36,24 +33,16 @@ import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler.SplitInfo
 import org.apache.spark.util.ManualClock
 
-class MockResolver extends DNSToSwitchMapping {
+class MockResolver extends SparkRackResolver {
 
-  override def resolve(names: JList[String]): JList[String] = {
-    if (names.size > 0 && names.get(0) == "host3") Arrays.asList("/rack2")
-    else Arrays.asList("/rack1")
+  override def resolve(conf: Configuration, hostName: String): String = {
+    if (hostName == "host3") "/rack2" else "/rack1"
   }
 
-  override def reloadCachedMappings() {}
-
-  def reloadCachedMappings(names: JList[String]) {}
 }
 
 class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach {
   val conf = new YarnConfiguration()
-  conf.setClass(
-    CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY,
-    classOf[MockResolver], classOf[DNSToSwitchMapping])
-
   val sparkConf = new SparkConf()
   sparkConf.set("spark.driver.host", "localhost")
   sparkConf.set("spark.driver.port", "4040")
@@ -107,7 +96,8 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
       rmClient,
       appAttemptId,
       new SecurityManager(sparkConf),
-      Map())
+      Map(),
+      new MockResolver())
   }
 
   def createContainer(host: String): Container = {