diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index aa29acfd704613d0e7eade13ea588c2d83b39853..982b83324e0fce46b8377faf3a142c986f221192 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,7 +20,8 @@ package org.apache.spark.storage
 import java.io._
 import java.nio.ByteBuffer
 
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable
+import scala.collection.mutable.HashMap
 import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.concurrent.duration._
 import scala.reflect.ClassTag
@@ -44,6 +45,7 @@ import org.apache.spark.unsafe.Platform
 import org.apache.spark.util._
 import org.apache.spark.util.io.ChunkedByteBuffer
 
+
 /* Class for returning a fetched block and associated metrics. */
 private[spark] class BlockResult(
     val data: Iterator[Any],
@@ -147,6 +149,8 @@ private[spark] class BlockManager(
   private val peerFetchLock = new Object
   private var lastPeerFetchTime = 0L
 
+  private var blockReplicationPolicy: BlockReplicationPolicy = _
+
   /**
    * Initializes the BlockManager with the given appId. This is not performed in the constructor as
    * the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -160,8 +164,24 @@ private[spark] class BlockManager(
     blockTransferService.init(this)
     shuffleClient.init(appId)
 
-    blockManagerId = BlockManagerId(
-      executorId, blockTransferService.hostName, blockTransferService.port)
+    blockReplicationPolicy = {
+      val priorityClass = conf.get(
+        "spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName)
+      val clazz = Utils.classForName(priorityClass)
+      val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy]
+      logInfo(s"Using $priorityClass for block replication policy")
+      ret
+    }
+
+    val id =
+      BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None)
+
+    val idFromMaster = master.registerBlockManager(
+      id,
+      maxMemory,
+      slaveEndpoint)
+
+    blockManagerId = if (idFromMaster != null) idFromMaster else id
 
     shuffleServerId = if (externalShuffleServiceEnabled) {
       logInfo(s"external shuffle service port = $externalShuffleServicePort")
@@ -170,12 +190,12 @@ private[spark] class BlockManager(
       blockManagerId
     }
 
-    master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
-
     // Register Executors' configuration with the local shuffle service, if one should exist.
     if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
       registerWithExternalShuffleServer()
     }
+
+    logInfo(s"Initialized BlockManager: $blockManagerId")
   }
 
   private def registerWithExternalShuffleServer() {
@@ -1111,7 +1131,7 @@ private[spark] class BlockManager(
   }
 
   /**
-   * Replicate block to another node. Not that this is a blocking call that returns after
+   * Replicate block to another node. Note that this is a blocking call that returns after
    * the block has been replicated.
    */
   private def replicate(
@@ -1119,101 +1139,78 @@ private[spark] class BlockManager(
       data: ChunkedByteBuffer,
       level: StorageLevel,
       classTag: ClassTag[_]): Unit = {
+
     val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
-    val numPeersToReplicateTo = level.replication - 1
-    val peersForReplication = new ArrayBuffer[BlockManagerId]
-    val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
-    val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
     val tLevel = StorageLevel(
       useDisk = level.useDisk,
       useMemory = level.useMemory,
       useOffHeap = level.useOffHeap,
       deserialized = level.deserialized,
       replication = 1)
-    val startTime = System.currentTimeMillis
-    val random = new Random(blockId.hashCode)
-
-    var replicationFailed = false
-    var failures = 0
-    var done = false
-
-    // Get cached list of peers
-    peersForReplication ++= getPeers(forceFetch = false)
-
-    // Get a random peer. Note that this selection of a peer is deterministic on the block id.
-    // So assuming the list of peers does not change and no replication failures,
-    // if there are multiple attempts in the same node to replicate the same block,
-    // the same set of peers will be selected.
-    def getRandomPeer(): Option[BlockManagerId] = {
-      // If replication had failed, then force update the cached list of peers and remove the peers
-      // that have been already used
-      if (replicationFailed) {
-        peersForReplication.clear()
-        peersForReplication ++= getPeers(forceFetch = true)
-        peersForReplication --= peersReplicatedTo
-        peersForReplication --= peersFailedToReplicateTo
-      }
-      if (!peersForReplication.isEmpty) {
-        Some(peersForReplication(random.nextInt(peersForReplication.size)))
-      } else {
-        None
-      }
-    }
 
-    // One by one choose a random peer and try uploading the block to it
-    // If replication fails (e.g., target peer is down), force the list of cached peers
-    // to be re-fetched from driver and then pick another random peer for replication. Also
-    // temporarily black list the peer for which replication failed.
-    //
-    // This selection of a peer and replication is continued in a loop until one of the
-    // following 3 conditions is fulfilled:
-    // (i) specified number of peers have been replicated to
-    // (ii) too many failures in replicating to peers
-    // (iii) no peer left to replicate to
-    //
-    while (!done) {
-      getRandomPeer() match {
-        case Some(peer) =>
-          try {
-            val onePeerStartTime = System.currentTimeMillis
-            logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
-            blockTransferService.uploadBlockSync(
-              peer.host,
-              peer.port,
-              peer.executorId,
-              blockId,
-              new NettyManagedBuffer(data.toNetty),
-              tLevel,
-              classTag)
-            logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms"
-              .format(System.currentTimeMillis - onePeerStartTime))
-            peersReplicatedTo += peer
-            peersForReplication -= peer
-            replicationFailed = false
-            if (peersReplicatedTo.size == numPeersToReplicateTo) {
-              done = true  // specified number of peers have been replicated to
-            }
-          } catch {
-            case NonFatal(e) =>
-              logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e)
-              failures += 1
-              replicationFailed = true
-              peersFailedToReplicateTo += peer
-              if (failures > maxReplicationFailures) { // too many failures in replicating to peers
-                done = true
-              }
+    val numPeersToReplicateTo = level.replication - 1
+
+    val startTime = System.nanoTime
+
+    var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId]
+    var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId]
+    var numFailures = 0
+
+    var peersForReplication = blockReplicationPolicy.prioritize(
+      blockManagerId,
+      getPeers(false),
+      mutable.HashSet.empty,
+      blockId,
+      numPeersToReplicateTo)
+
+    while(numFailures <= maxReplicationFailures &&
+        !peersForReplication.isEmpty &&
+        peersReplicatedTo.size != numPeersToReplicateTo) {
+      val peer = peersForReplication.head
+      try {
+        val onePeerStartTime = System.nanoTime
+        logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
+        blockTransferService.uploadBlockSync(
+          peer.host,
+          peer.port,
+          peer.executorId,
+          blockId,
+          new NettyManagedBuffer(data.toNetty),
+          tLevel,
+          classTag)
+        logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
+          s" in ${(System.nanoTime - onePeerStartTime).toDouble / 1e6} ms")
+        peersForReplication = peersForReplication.tail
+        peersReplicatedTo += peer
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e)
+          peersFailedToReplicateTo += peer
+          // we have a failed replication, so we get the list of peers again
+          // we don't want peers we have already replicated to and the ones that
+          // have failed previously
+          val filteredPeers = getPeers(true).filter { p =>
+            !peersFailedToReplicateTo.contains(p) && !peersReplicatedTo.contains(p)
           }
-        case None => // no peer left to replicate to
-          done = true
+
+          numFailures += 1
+          peersForReplication = blockReplicationPolicy.prioritize(
+            blockManagerId,
+            filteredPeers,
+            peersReplicatedTo,
+            blockId,
+            numPeersToReplicateTo - peersReplicatedTo.size)
       }
     }
-    val timeTakeMs = (System.currentTimeMillis - startTime)
+
     logDebug(s"Replicating $blockId of ${data.size} bytes to " +
-      s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms")
+      s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms")
     if (peersReplicatedTo.size < numPeersToReplicateTo) {
       logWarning(s"Block $blockId replicated to only " +
         s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers")
     }
+
+    logDebug(s"block $blockId replicated to ${peersReplicatedTo.mkString(", ")}")
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index f255f5be63fcf48ae661ffd7c7243942fda5117d..c37a3604d28fadd33f985d7e451d5ebe736c33d7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -37,10 +37,11 @@ import org.apache.spark.util.Utils
 class BlockManagerId private (
     private var executorId_ : String,
     private var host_ : String,
-    private var port_ : Int)
+    private var port_ : Int,
+    private var topologyInfo_ : Option[String])
   extends Externalizable {
 
-  private def this() = this(null, null, 0)  // For deserialization only
+  private def this() = this(null, null, 0, None)  // For deserialization only
 
   def executorId: String = executorId_
 
@@ -60,6 +61,8 @@ class BlockManagerId private (
 
   def port: Int = port_
 
+  def topologyInfo: Option[String] = topologyInfo_
+
   def isDriver: Boolean = {
     executorId == SparkContext.DRIVER_IDENTIFIER ||
       executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
@@ -69,24 +72,33 @@ class BlockManagerId private (
     out.writeUTF(executorId_)
     out.writeUTF(host_)
     out.writeInt(port_)
+    out.writeBoolean(topologyInfo_.isDefined)
+    // we only write topologyInfo if we have it
+    topologyInfo.foreach(out.writeUTF(_))
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     executorId_ = in.readUTF()
     host_ = in.readUTF()
     port_ = in.readInt()
+    val isTopologyInfoAvailable = in.readBoolean()
+    topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None
   }
 
   @throws(classOf[IOException])
   private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
 
-  override def toString: String = s"BlockManagerId($executorId, $host, $port)"
+  override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)"
 
-  override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
+  override def hashCode: Int =
+    ((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode
 
   override def equals(that: Any): Boolean = that match {
     case id: BlockManagerId =>
-      executorId == id.executorId && port == id.port && host == id.host
+      executorId == id.executorId &&
+        port == id.port &&
+        host == id.host &&
+        topologyInfo == id.topologyInfo
     case _ =>
       false
   }
@@ -101,10 +113,18 @@ private[spark] object BlockManagerId {
    * @param execId ID of the executor.
    * @param host Host name of the block manager.
    * @param port Port of the block manager.
+   * @param topologyInfo topology information for the blockmanager, if available
+   *                     This can be network topology information for use while choosing peers
+   *                     while replicating data blocks. More information available here:
+   *                     [[org.apache.spark.storage.TopologyMapper]]
    * @return A new [[org.apache.spark.storage.BlockManagerId]].
    */
-  def apply(execId: String, host: String, port: Int): BlockManagerId =
-    getCachedBlockManagerId(new BlockManagerId(execId, host, port))
+  def apply(
+      execId: String,
+      host: String,
+      port: Int,
+      topologyInfo: Option[String] = None): BlockManagerId =
+    getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo))
 
   def apply(in: ObjectInput): BlockManagerId = {
     val obj = new BlockManagerId()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 8655cf10fc28f7377170d1f54568348cc7b5a819..7a600068912b19515abe46fe61b800ff8f47414a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -50,12 +50,20 @@ class BlockManagerMaster(
     logInfo("Removal of executor " + execId + " requested")
   }
 
-  /** Register the BlockManager's id with the driver. */
+  /**
+   * Register the BlockManager's id with the driver. The input BlockManagerId does not contain
+   * topology information. This information is obtained from the master and we respond with an
+   * updated BlockManagerId fleshed out with this information.
+   */
   def registerBlockManager(
-      blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
+      blockManagerId: BlockManagerId,
+      maxMemSize: Long,
+      slaveEndpoint: RpcEndpointRef): BlockManagerId = {
     logInfo(s"Registering BlockManager $blockManagerId")
-    tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
-    logInfo(s"Registered BlockManager $blockManagerId")
+    val updatedId = driverEndpoint.askWithRetry[BlockManagerId](
+      RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
+    logInfo(s"Registered BlockManager $updatedId")
+    updatedId
   }
 
   def updateBlockInfo(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 8fa12150114dba83f2a335c9d488ac399c5ce2a3..145c434a4f0cf4a7885feee6ee2e79485fdd9d57 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -55,10 +55,21 @@ class BlockManagerMasterEndpoint(
   private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
   private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
 
+  private val topologyMapper = {
+    val topologyMapperClassName = conf.get(
+      "spark.storage.replication.topologyMapper", classOf[DefaultTopologyMapper].getName)
+    val clazz = Utils.classForName(topologyMapperClassName)
+    val mapper =
+      clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper]
+    logInfo(s"Using $topologyMapperClassName for getting topology information")
+    mapper
+  }
+
+  logInfo("BlockManagerMasterEndpoint up")
+
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
-      register(blockManagerId, maxMemSize, slaveEndpoint)
-      context.reply(true)
+      context.reply(register(blockManagerId, maxMemSize, slaveEndpoint))
 
     case _updateBlockInfo @
         UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
@@ -298,7 +309,21 @@ class BlockManagerMasterEndpoint(
     ).map(_.flatten.toSeq)
   }
 
-  private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
+  /**
+   * Returns the BlockManagerId with topology information populated, if available.
+   */
+  private def register(
+      idWithoutTopologyInfo: BlockManagerId,
+      maxMemSize: Long,
+      slaveEndpoint: RpcEndpointRef): BlockManagerId = {
+    // the dummy id is not expected to contain the topology information.
+    // we get that info here and respond back with a more fleshed out block manager id
+    val id = BlockManagerId(
+      idWithoutTopologyInfo.executorId,
+      idWithoutTopologyInfo.host,
+      idWithoutTopologyInfo.port,
+      topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host))
+
     val time = System.currentTimeMillis()
     if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
@@ -318,6 +343,7 @@ class BlockManagerMasterEndpoint(
         id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
     }
     listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
+    id
   }
 
   private def updateBlockInfo(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
new file mode 100644
index 0000000000000000000000000000000000000000..bf087af16a5b1400b9ec3c34eaa1901f1f1aaee4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+
+/**
+ * ::DeveloperApi::
+ * BlockReplicationPrioritization provides logic for prioritizing a sequence of peers for
+ * replicating blocks. BlockManager will replicate to each peer returned in order until the
+ * desired replication order is reached. If a replication fails, prioritize() will be called
+ * again to get a fresh prioritization.
+ */
+@DeveloperApi
+trait BlockReplicationPolicy {
+
+  /**
+   * Method to prioritize a bunch of candidate peers of a block
+   *
+   * @param blockManagerId Id of the current BlockManager for self identification
+   * @param peers A list of peers of a BlockManager
+   * @param peersReplicatedTo Set of peers already replicated to
+   * @param blockId BlockId of the block being replicated. This can be used as a source of
+   *                randomness if needed.
+   * @param numReplicas Number of peers we need to replicate to
+   * @return A prioritized list of peers. Lower the index of a peer, higher its priority.
+   *         This returns a list of size at most `numPeersToReplicateTo`.
+   */
+  def prioritize(
+      blockManagerId: BlockManagerId,
+      peers: Seq[BlockManagerId],
+      peersReplicatedTo: mutable.HashSet[BlockManagerId],
+      blockId: BlockId,
+      numReplicas: Int): List[BlockManagerId]
+}
+
+@DeveloperApi
+class RandomBlockReplicationPolicy
+  extends BlockReplicationPolicy
+  with Logging {
+
+  /**
+   * Method to prioritize a bunch of candidate peers of a block. This is a basic implementation,
+   * that just makes sure we put blocks on different hosts, if possible
+   *
+   * @param blockManagerId Id of the current BlockManager for self identification
+   * @param peers A list of peers of a BlockManager
+   * @param peersReplicatedTo Set of peers already replicated to
+   * @param blockId BlockId of the block being replicated. This can be used as a source of
+   *                randomness if needed.
+   * @return A prioritized list of peers. Lower the index of a peer, higher its priority
+   */
+  override def prioritize(
+      blockManagerId: BlockManagerId,
+      peers: Seq[BlockManagerId],
+      peersReplicatedTo: mutable.HashSet[BlockManagerId],
+      blockId: BlockId,
+      numReplicas: Int): List[BlockManagerId] = {
+    val random = new Random(blockId.hashCode)
+    logDebug(s"Input peers : ${peers.mkString(", ")}")
+    val prioritizedPeers = if (peers.size > numReplicas) {
+      getSampleIds(peers.size, numReplicas, random).map(peers(_))
+    } else {
+      if (peers.size < numReplicas) {
+        logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.")
+      }
+      random.shuffle(peers).toList
+    }
+    logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}")
+    prioritizedPeers
+  }
+
+  /**
+   * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while
+   * minimizing space usage
+   * [[http://math.stackexchange.com/questions/178690/
+   * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]]
+   *
+   * @param n total number of indices
+   * @param m number of samples needed
+   * @param r random number generator
+   * @return list of m random unique indices
+   */
+  private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
+    val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) =>
+      val t = r.nextInt(i) + 1
+      if (set.contains(t)) set + i else set + t
+    }
+    // we shuffle the result to ensure a random arrangement within the sample
+    // to avoid any bias from set implementations
+    r.shuffle(indices.map(_ - 1).toList)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a0f0fdef8e9484b2f9b828f7665f1e91e1b8929b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * ::DeveloperApi::
+ * TopologyMapper provides topology information for a given host
+ * @param conf SparkConf to get required properties, if needed
+ */
+@DeveloperApi
+abstract class TopologyMapper(conf: SparkConf) {
+  /**
+   * Gets the topology information given the host name
+   *
+   * @param hostname Hostname
+   * @return topology information for the given hostname. One can use a 'topology delimiter'
+   *         to make this topology information nested.
+   *         For example : ‘/myrack/myhost’, where ‘/’ is the topology delimiter,
+   *         ‘myrack’ is the topology identifier, and ‘myhost’ is the individual host.
+   *         This function only returns the topology information without the hostname.
+   *         This information can be used when choosing executors for block replication
+   *         to discern executors from a different rack than a candidate executor, for example.
+   *
+   *         An implementation can choose to use empty strings or None in case topology info
+   *         is not available. This would imply that all such executors belong to the same rack.
+   */
+  def getTopologyForHost(hostname: String): Option[String]
+}
+
+/**
+ * A TopologyMapper that assumes all nodes are in the same rack
+ */
+@DeveloperApi
+class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
+  override def getTopologyForHost(hostname: String): Option[String] = {
+    logDebug(s"Got a request for $hostname")
+    None
+  }
+}
+
+/**
+ * A simple file based topology mapper. This expects topology information provided as a
+ * [[java.util.Properties]] file. The name of the file is obtained from SparkConf property
+ * `spark.storage.replication.topologyFile`. To use this topology mapper, set the
+ * `spark.storage.replication.topologyMapper` property to
+ * [[org.apache.spark.storage.FileBasedTopologyMapper]]
+ * @param conf SparkConf object
+ */
+@DeveloperApi
+class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
+  val topologyFile = conf.getOption("spark.storage.replication.topologyFile")
+  require(topologyFile.isDefined, "Please specify topology file via " +
+    "spark.storage.replication.topologyFile for FileBasedTopologyMapper.")
+  val topologyMap = Utils.getPropertiesFromFile(topologyFile.get)
+
+  override def getTopologyForHost(hostname: String): Option[String] = {
+    val topology = topologyMap.get(hostname)
+    if (topology.isDefined) {
+      logDebug(s"$hostname -> ${topology.get}")
+    } else {
+      logWarning(s"$hostname does not have any topology information")
+    }
+    topology
+  }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index e1c1787cbd15e30b49e0d5c32f2adfb0a810162b..f4bfdc2fd69a93cd85088bd61fb483889774796a 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -346,6 +346,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite
     }
   }
 
+
+
   /**
    * Test replication of blocks with different storage levels (various combinations of
    * memory, disk & serialization). For each storage level, this function tests every store
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..800c3899f1a724484ab3cc71548cd917fac395a7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import scala.collection.mutable
+
+import org.scalatest.{BeforeAndAfter, Matchers}
+
+import org.apache.spark.{LocalSparkContext, SparkFunSuite}
+
+class BlockReplicationPolicySuite extends SparkFunSuite
+  with Matchers
+  with BeforeAndAfter
+  with LocalSparkContext {
+
+  // Implicitly convert strings to BlockIds for test clarity.
+  private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+
+  /**
+   * Test if we get the required number of peers when using random sampling from
+   * RandomBlockReplicationPolicy
+   */
+  test(s"block replication - random block replication policy") {
+    val numBlockManagers = 10
+    val storeSize = 1000
+    val blockManagers = (1 to numBlockManagers).map { i =>
+      BlockManagerId(s"store-$i", "localhost", 1000 + i, None)
+    }
+    val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None)
+    val replicationPolicy = new RandomBlockReplicationPolicy
+    val blockId = "test-block"
+
+    (1 to 10).foreach {numReplicas =>
+      logDebug(s"Num replicas : $numReplicas")
+      val randomPeers = replicationPolicy.prioritize(
+        candidateBlockManager,
+        blockManagers,
+        mutable.HashSet.empty[BlockManagerId],
+        blockId,
+        numReplicas
+      )
+      logDebug(s"Random peers : ${randomPeers.mkString(", ")}")
+      assert(randomPeers.toSet.size === numReplicas)
+
+      // choosing n peers out of n
+      val secondPass = replicationPolicy.prioritize(
+        candidateBlockManager,
+        randomPeers,
+        mutable.HashSet.empty[BlockManagerId],
+        blockId,
+        numReplicas
+      )
+      logDebug(s"Random peers : ${secondPass.mkString(", ")}")
+      assert(secondPass.toSet.size === numReplicas)
+    }
+
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..bbd252d7be7ea110aacdf5ef81f69e6bf3c2942e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/TopologyMapperSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{File, FileOutputStream}
+
+import org.scalatest.{BeforeAndAfter, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+
+class TopologyMapperSuite  extends SparkFunSuite
+  with Matchers
+  with BeforeAndAfter
+  with LocalSparkContext {
+
+  test("File based Topology Mapper") {
+    val numHosts = 100
+    val numRacks = 4
+    val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap
+    val propsFile = createPropertiesFile(props)
+
+    val sparkConf = (new SparkConf(false))
+    sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
+    val topologyMapper = new FileBasedTopologyMapper(sparkConf)
+
+    props.foreach {case (host, topology) =>
+      val obtainedTopology = topologyMapper.getTopologyForHost(host)
+      assert(obtainedTopology.isDefined)
+      assert(obtainedTopology.get === topology)
+    }
+
+    // we get None for hosts not in the file
+    assert(topologyMapper.getTopologyForHost("host").isEmpty)
+
+    cleanup(propsFile)
+  }
+
+  def createPropertiesFile(props: Map[String, String]): File = {
+    val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
+    val fileOS = new FileOutputStream(testFile)
+    props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)}
+    fileOS.close
+    testFile
+  }
+
+  def cleanup(testFile: File): Unit = {
+    testFile.getParentFile.listFiles.filter { file =>
+      file.getName.startsWith(testFile.getName)
+    }.foreach { _.delete() }
+  }
+
+}