Commit a26afd52 authored by Shubham Chopra's avatar Shubham Chopra Committed by Reynold Xin

[SPARK-15353][CORE] Making peer selection for block replication pluggable

## What changes were proposed in this pull request?

This PR makes block replication strategies pluggable. It provides two traits that can be implemented: one that maps a host to its topology and is used in the master, and a second that prioritizes a list of peers for block replication and runs in the executors.

This patch contains default implementations of these traits that make sure current Spark behavior is unchanged.
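For example, a deployment could swap in custom implementations purely through configuration. A minimal sketch using the two configuration keys this patch introduces (the policy class name shown is hypothetical; `FileBasedTopologyMapper` is added by this patch):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Executors use this policy to prioritize replication peers;
  // defaults to org.apache.spark.storage.RandomBlockReplicationPolicy.
  .set("spark.storage.replication.policy",
    "com.example.RackAwareBlockReplicationPolicy") // hypothetical class
  // The master uses this mapper to resolve a host's topology;
  // defaults to org.apache.spark.storage.DefaultTopologyMapper.
  .set("spark.storage.replication.topologyMapper",
    "org.apache.spark.storage.FileBasedTopologyMapper")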

## How was this patch tested?

This patch should not change Spark behavior in any way, and was tested with unit tests for storage.

Author: Shubham Chopra <schopra31@bloomberg.net>

Closes #13152 from shubhamchopra/RackAwareBlockReplication.
parent 81455a9c
Showing 492 additions and 99 deletions
@@ -20,7 +20,8 @@ package org.apache.spark.storage
import java.io._
import java.nio.ByteBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.mutable
import scala.collection.mutable.HashMap
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import scala.reflect.ClassTag
@@ -44,6 +45,7 @@ import org.apache.spark.unsafe.Platform
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer
/* Class for returning a fetched block and associated metrics. */
private[spark] class BlockResult(
val data: Iterator[Any],
@@ -147,6 +149,8 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
private var blockReplicationPolicy: BlockReplicationPolicy = _
/**
* Initializes the BlockManager with the given appId. This is not performed in the constructor as
* the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -160,8 +164,24 @@ private[spark] class BlockManager(
blockTransferService.init(this)
shuffleClient.init(appId)
blockManagerId = BlockManagerId(
executorId, blockTransferService.hostName, blockTransferService.port)
blockReplicationPolicy = {
val priorityClass = conf.get(
"spark.storage.replication.policy", classOf[RandomBlockReplicationPolicy].getName)
val clazz = Utils.classForName(priorityClass)
val ret = clazz.newInstance.asInstanceOf[BlockReplicationPolicy]
logInfo(s"Using $priorityClass for block replication policy")
ret
}
val id =
BlockManagerId(executorId, blockTransferService.hostName, blockTransferService.port, None)
val idFromMaster = master.registerBlockManager(
id,
maxMemory,
slaveEndpoint)
blockManagerId = if (idFromMaster != null) idFromMaster else id
shuffleServerId = if (externalShuffleServiceEnabled) {
logInfo(s"external shuffle service port = $externalShuffleServicePort")
@@ -170,12 +190,12 @@ private[spark] class BlockManager(
blockManagerId
}
master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
// Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
registerWithExternalShuffleServer()
}
logInfo(s"Initialized BlockManager: $blockManagerId")
}
private def registerWithExternalShuffleServer() {
@@ -1111,7 +1131,7 @@ private[spark] class BlockManager(
}
/**
* Replicate block to another node. Not that this is a blocking call that returns after
* Replicate block to another node. Note that this is a blocking call that returns after
* the block has been replicated.
*/
private def replicate(
@@ -1119,101 +1139,78 @@ private[spark] class BlockManager(
data: ChunkedByteBuffer,
level: StorageLevel,
classTag: ClassTag[_]): Unit = {
val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1)
val numPeersToReplicateTo = level.replication - 1
val peersForReplication = new ArrayBuffer[BlockManagerId]
val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
val tLevel = StorageLevel(
useDisk = level.useDisk,
useMemory = level.useMemory,
useOffHeap = level.useOffHeap,
deserialized = level.deserialized,
replication = 1)
val startTime = System.currentTimeMillis
val random = new Random(blockId.hashCode)
var replicationFailed = false
var failures = 0
var done = false
// Get cached list of peers
peersForReplication ++= getPeers(forceFetch = false)
// Get a random peer. Note that this selection of a peer is deterministic on the block id.
// So assuming the list of peers does not change and no replication failures,
// if there are multiple attempts in the same node to replicate the same block,
// the same set of peers will be selected.
def getRandomPeer(): Option[BlockManagerId] = {
// If replication had failed, then force update the cached list of peers and remove the peers
// that have been already used
if (replicationFailed) {
peersForReplication.clear()
peersForReplication ++= getPeers(forceFetch = true)
peersForReplication --= peersReplicatedTo
peersForReplication --= peersFailedToReplicateTo
}
if (!peersForReplication.isEmpty) {
Some(peersForReplication(random.nextInt(peersForReplication.size)))
} else {
None
}
}
// One by one choose a random peer and try uploading the block to it
// If replication fails (e.g., target peer is down), force the list of cached peers
// to be re-fetched from driver and then pick another random peer for replication. Also
// temporarily black list the peer for which replication failed.
//
// This selection of a peer and replication is continued in a loop until one of the
// following 3 conditions is fulfilled:
// (i) specified number of peers have been replicated to
// (ii) too many failures in replicating to peers
// (iii) no peer left to replicate to
//
while (!done) {
getRandomPeer() match {
case Some(peer) =>
try {
val onePeerStartTime = System.currentTimeMillis
logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
blockTransferService.uploadBlockSync(
peer.host,
peer.port,
peer.executorId,
blockId,
new NettyManagedBuffer(data.toNetty),
tLevel,
classTag)
logTrace(s"Replicated $blockId of ${data.size} bytes to $peer in %s ms"
.format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
if (peersReplicatedTo.size == numPeersToReplicateTo) {
done = true // specified number of peers have been replicated to
}
} catch {
case NonFatal(e) =>
logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e)
failures += 1
replicationFailed = true
peersFailedToReplicateTo += peer
if (failures > maxReplicationFailures) { // too many failures in replicating to peers
done = true
}
val numPeersToReplicateTo = level.replication - 1
val startTime = System.nanoTime
var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId]
var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId]
var numFailures = 0
var peersForReplication = blockReplicationPolicy.prioritize(
blockManagerId,
getPeers(false),
mutable.HashSet.empty,
blockId,
numPeersToReplicateTo)
while(numFailures <= maxReplicationFailures &&
!peersForReplication.isEmpty &&
peersReplicatedTo.size != numPeersToReplicateTo) {
val peer = peersForReplication.head
try {
val onePeerStartTime = System.nanoTime
logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
blockTransferService.uploadBlockSync(
peer.host,
peer.port,
peer.executorId,
blockId,
new NettyManagedBuffer(data.toNetty),
tLevel,
classTag)
logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
s" in ${(System.nanoTime - onePeerStartTime).toDouble / 1e6} ms")
peersForReplication = peersForReplication.tail
peersReplicatedTo += peer
} catch {
case NonFatal(e) =>
logWarning(s"Failed to replicate $blockId to $peer, failure #$numFailures", e)
peersFailedToReplicateTo += peer
// we have a failed replication, so we get the list of peers again
// we don't want peers we have already replicated to and the ones that
// have failed previously
val filteredPeers = getPeers(true).filter { p =>
!peersFailedToReplicateTo.contains(p) && !peersReplicatedTo.contains(p)
}
case None => // no peer left to replicate to
done = true
numFailures += 1
peersForReplication = blockReplicationPolicy.prioritize(
blockManagerId,
filteredPeers,
peersReplicatedTo,
blockId,
numPeersToReplicateTo - peersReplicatedTo.size)
}
}
val timeTakeMs = (System.currentTimeMillis - startTime)
logDebug(s"Replicating $blockId of ${data.size} bytes to " +
s"${peersReplicatedTo.size} peer(s) took $timeTakeMs ms")
s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms")
if (peersReplicatedTo.size < numPeersToReplicateTo) {
logWarning(s"Block $blockId replicated to only " +
s"${peersReplicatedTo.size} peer(s) instead of $numPeersToReplicateTo peers")
}
logDebug(s"block $blockId replicated to ${peersReplicatedTo.mkString(", ")}")
}
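The retry budget used by the loop above comes from `spark.storage.maxReplicationFailures` (default 1, unchanged by this patch). A sketch of raising it for jobs that must reach their replication level more reliably (the value shown is arbitrary):

// Tolerate up to 3 failed uploads before giving up on reaching
// the requested replication level.
conf.set("spark.storage.maxReplicationFailures", "3")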
/**
......
@@ -37,10 +37,11 @@ import org.apache.spark.util.Utils
class BlockManagerId private (
private var executorId_ : String,
private var host_ : String,
private var port_ : Int)
private var port_ : Int,
private var topologyInfo_ : Option[String])
extends Externalizable {
private def this() = this(null, null, 0) // For deserialization only
private def this() = this(null, null, 0, None) // For deserialization only
def executorId: String = executorId_
@@ -60,6 +61,8 @@ class BlockManagerId private (
def port: Int = port_
def topologyInfo: Option[String] = topologyInfo_
def isDriver: Boolean = {
executorId == SparkContext.DRIVER_IDENTIFIER ||
executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
@@ -69,24 +72,33 @@ class BlockManagerId private (
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
out.writeBoolean(topologyInfo_.isDefined)
// we only write topologyInfo if we have it
topologyInfo.foreach(out.writeUTF(_))
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
val isTopologyInfoAvailable = in.readBoolean()
topologyInfo_ = if (isTopologyInfoAvailable) Option(in.readUTF()) else None
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
override def toString: String = s"BlockManagerId($executorId, $host, $port)"
override def toString: String = s"BlockManagerId($executorId, $host, $port, $topologyInfo)"
override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
override def hashCode: Int =
((executorId.hashCode * 41 + host.hashCode) * 41 + port) * 41 + topologyInfo.hashCode
override def equals(that: Any): Boolean = that match {
case id: BlockManagerId =>
executorId == id.executorId && port == id.port && host == id.host
executorId == id.executorId &&
port == id.port &&
host == id.host &&
topologyInfo == id.topologyInfo
case _ =>
false
}
@@ -101,10 +113,18 @@ private[spark] object BlockManagerId {
* @param execId ID of the executor.
* @param host Host name of the block manager.
* @param port Port of the block manager.
* @param topologyInfo topology information for the blockmanager, if available
* This can be network topology information for use while choosing peers
* while replicating data blocks. More information available here:
* [[org.apache.spark.storage.TopologyMapper]]
* @return A new [[org.apache.spark.storage.BlockManagerId]].
*/
def apply(execId: String, host: String, port: Int): BlockManagerId =
getCachedBlockManagerId(new BlockManagerId(execId, host, port))
def apply(
execId: String,
host: String,
port: Int,
topologyInfo: Option[String] = None): BlockManagerId =
getCachedBlockManagerId(new BlockManagerId(execId, host, port, topologyInfo))
def apply(in: ObjectInput): BlockManagerId = {
val obj = new BlockManagerId()
......
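With the extra field, code in the `org.apache.spark.storage` package can attach topology information when building an id, while existing three-argument call sites keep compiling because the new parameter defaults to `None`. A sketch mirroring the test suites below:

// Annotated with rack information supplied by a TopologyMapper.
val id = BlockManagerId("exec-1", "host-1", 50010, Some("/rack-1"))
assert(id.topologyInfo == Some("/rack-1"))

// Old call sites are unchanged: topologyInfo defaults to None.
val legacyId = BlockManagerId("exec-2", "host-2", 50011)
assert(legacyId.topologyInfo.isEmpty)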
@@ -50,12 +50,20 @@ class BlockManagerMaster(
logInfo("Removal of executor " + execId + " requested")
}
/** Register the BlockManager's id with the driver. */
/**
* Register the BlockManager's id with the driver. The input BlockManagerId does not contain
* topology information. This information is obtained from the master and we respond with an
* updated BlockManagerId fleshed out with this information.
*/
def registerBlockManager(
blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
blockManagerId: BlockManagerId,
maxMemSize: Long,
slaveEndpoint: RpcEndpointRef): BlockManagerId = {
logInfo(s"Registering BlockManager $blockManagerId")
tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
logInfo(s"Registered BlockManager $blockManagerId")
val updatedId = driverEndpoint.askWithRetry[BlockManagerId](
RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
logInfo(s"Registered BlockManager $updatedId")
updatedId
}
def updateBlockInfo(
......
@@ -55,10 +55,21 @@ class BlockManagerMasterEndpoint(
private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
private val topologyMapper = {
val topologyMapperClassName = conf.get(
"spark.storage.replication.topologyMapper", classOf[DefaultTopologyMapper].getName)
val clazz = Utils.classForName(topologyMapperClassName)
val mapper =
clazz.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[TopologyMapper]
logInfo(s"Using $topologyMapperClassName for getting topology information")
mapper
}
logInfo("BlockManagerMasterEndpoint up")
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
register(blockManagerId, maxMemSize, slaveEndpoint)
context.reply(true)
context.reply(register(blockManagerId, maxMemSize, slaveEndpoint))
case _updateBlockInfo @
UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
@@ -298,7 +309,21 @@ class BlockManagerMasterEndpoint(
).map(_.flatten.toSeq)
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
/**
* Returns the BlockManagerId with topology information populated, if available.
*/
private def register(
idWithoutTopologyInfo: BlockManagerId,
maxMemSize: Long,
slaveEndpoint: RpcEndpointRef): BlockManagerId = {
// the dummy id is not expected to contain the topology information.
// we get that info here and respond back with a more fleshed out block manager id
val id = BlockManagerId(
idWithoutTopologyInfo.executorId,
idWithoutTopologyInfo.host,
idWithoutTopologyInfo.port,
topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host))
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -318,6 +343,7 @@ class BlockManagerMasterEndpoint(
id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
id
}
private def updateBlockInfo(
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import scala.collection.mutable
import scala.util.Random
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
/**
* ::DeveloperApi::
* BlockReplicationPolicy provides logic for prioritizing a sequence of peers for
* replicating blocks. BlockManager will replicate to each peer returned in order until the
* desired number of replicas is reached. If a replication fails, prioritize() will be called
* again to get a fresh prioritization.
*/
@DeveloperApi
trait BlockReplicationPolicy {
/**
* Method to prioritize a set of candidate peers of a block
*
* @param blockManagerId Id of the current BlockManager for self identification
* @param peers A list of peers of a BlockManager
* @param peersReplicatedTo Set of peers already replicated to
* @param blockId BlockId of the block being replicated. This can be used as a source of
* randomness if needed.
* @param numReplicas Number of peers we need to replicate to
* @return A prioritized list of peers: the lower a peer's index, the higher its priority.
* This returns a list of size at most `numReplicas`.
*/
def prioritize(
blockManagerId: BlockManagerId,
peers: Seq[BlockManagerId],
peersReplicatedTo: mutable.HashSet[BlockManagerId],
blockId: BlockId,
numReplicas: Int): List[BlockManagerId]
}
@DeveloperApi
class RandomBlockReplicationPolicy
extends BlockReplicationPolicy
with Logging {
/**
* Method to prioritize a set of candidate peers of a block. This is a basic implementation
* that just makes sure we put blocks on different hosts, if possible.
*
* @param blockManagerId Id of the current BlockManager for self identification
* @param peers A list of peers of a BlockManager
* @param peersReplicatedTo Set of peers already replicated to
* @param blockId BlockId of the block being replicated. This can be used as a source of
* randomness if needed.
* @return A prioritized list of peers: the lower a peer's index, the higher its priority
*/
override def prioritize(
blockManagerId: BlockManagerId,
peers: Seq[BlockManagerId],
peersReplicatedTo: mutable.HashSet[BlockManagerId],
blockId: BlockId,
numReplicas: Int): List[BlockManagerId] = {
val random = new Random(blockId.hashCode)
logDebug(s"Input peers : ${peers.mkString(", ")}")
val prioritizedPeers = if (peers.size > numReplicas) {
getSampleIds(peers.size, numReplicas, random).map(peers(_))
} else {
if (peers.size < numReplicas) {
logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.")
}
random.shuffle(peers).toList
}
logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}")
prioritizedPeers
}
/**
* Uses the sampling algorithm by Robert Floyd. It finds a random sample in O(n) while
* minimizing space usage.
* [[http://math.stackexchange.com/questions/178690/
* whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]]
*
* @param n total number of indices
* @param m number of samples needed
* @param r random number generator
* @return list of m random unique indices
*/
private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = {
val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) =>
val t = r.nextInt(i) + 1
if (set.contains(t)) set + i else set + t
}
// we shuffle the result to ensure a random arrangement within the sample
// to avoid any bias from set implementations
r.shuffle(indices.map(_ - 1).toList)
}
}
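As a sketch of what the pluggable policy enables (this class is hypothetical and not part of the patch), an implementation could use the new `topologyInfo` field to prefer peers on other racks before falling back to its own:

package org.apache.spark.storage

import scala.collection.mutable
import scala.util.Random

/** Hypothetical example: prefer peers whose topology differs from ours. */
class PreferDifferentRackPolicy extends BlockReplicationPolicy {
  override def prioritize(
      blockManagerId: BlockManagerId,
      peers: Seq[BlockManagerId],
      peersReplicatedTo: mutable.HashSet[BlockManagerId],
      blockId: BlockId,
      numReplicas: Int): List[BlockManagerId] = {
    // Deterministic on the block id, like RandomBlockReplicationPolicy.
    val random = new Random(blockId.hashCode)
    val (otherRacks, sameRack) =
      peers.partition(_.topologyInfo != blockManagerId.topologyInfo)
    // Off-rack peers first, then on-rack peers, each in random order.
    (random.shuffle(otherRacks) ++ random.shuffle(sameRack))
      .take(numReplicas)
      .toList
  }
}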
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
* ::DeveloperApi::
* TopologyMapper provides topology information for a given host
* @param conf SparkConf to get required properties, if needed
*/
@DeveloperApi
abstract class TopologyMapper(conf: SparkConf) {
/**
* Gets the topology information given the host name
*
* @param hostname Hostname
* @return topology information for the given hostname. One can use a 'topology delimiter'
* to make this topology information nested.
* For example: '/myrack/myhost', where '/' is the topology delimiter,
* 'myrack' is the topology identifier, and 'myhost' is the individual host.
* This function only returns the topology information without the hostname.
* This information can be used when choosing executors for block replication
* to discern executors from a different rack than a candidate executor, for example.
*
* An implementation can choose to use empty strings or None in case topology info
* is not available. This would imply that all such executors belong to the same rack.
*/
def getTopologyForHost(hostname: String): Option[String]
}
/**
* A TopologyMapper that assumes all nodes are in the same rack
*/
@DeveloperApi
class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
override def getTopologyForHost(hostname: String): Option[String] = {
logDebug(s"Got a request for $hostname")
None
}
}
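Any other mapper can be plugged in by subclassing `TopologyMapper`; for instance, a hypothetical mapper that derives the rack from a hostname prefix such as `rack1-node07` (assumption: hostnames encode the rack this way):

/** Hypothetical mapper: derive the rack from a "rackN-..." hostname prefix. */
class PrefixTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) {
  override def getTopologyForHost(hostname: String): Option[String] = {
    // "rack1-node07" -> Some("/rack1"); unrecognized hosts -> None.
    hostname.split("-").headOption
      .filter(_.startsWith("rack"))
      .map(rack => s"/$rack")
  }
}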
/**
* A simple file-based topology mapper. This expects topology information provided as a
* [[java.util.Properties]] file. The name of the file is obtained from SparkConf property
* `spark.storage.replication.topologyFile`. To use this topology mapper, set the
* `spark.storage.replication.topologyMapper` property to
* [[org.apache.spark.storage.FileBasedTopologyMapper]]
* @param conf SparkConf object
*/
@DeveloperApi
class FileBasedTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging {
val topologyFile = conf.getOption("spark.storage.replication.topologyFile")
require(topologyFile.isDefined, "Please specify topology file via " +
"spark.storage.replication.topologyFile for FileBasedTopologyMapper.")
val topologyMap = Utils.getPropertiesFromFile(topologyFile.get)
override def getTopologyForHost(hostname: String): Option[String] = {
val topology = topologyMap.get(hostname)
if (topology.isDefined) {
logDebug(s"$hostname -> ${topology.get}")
} else {
logWarning(s"$hostname does not have any topology information")
}
topology
}
}
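Enabling it takes two properties plus a plain Java properties file mapping hostnames to topology strings. A sketch (the file path is illustrative; the key/value format mirrors the TopologyMapperSuite below):

// Contents of /etc/spark/topology.properties (illustrative path):
//   host-1=rack-1
//   host-2=rack-2
conf.set("spark.storage.replication.topologyMapper",
  "org.apache.spark.storage.FileBasedTopologyMapper")
conf.set("spark.storage.replication.topologyFile",
  "/etc/spark/topology.properties")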
@@ -346,6 +346,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite
}
}
/**
* Test replication of blocks with different storage levels (various combinations of
* memory, disk & serialization). For each storage level, this function tests every store
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import scala.collection.mutable
import org.scalatest.{BeforeAndAfter, Matchers}
import org.apache.spark.{LocalSparkContext, SparkFunSuite}
class BlockReplicationPolicySuite extends SparkFunSuite
with Matchers
with BeforeAndAfter
with LocalSparkContext {
// Implicitly convert strings to BlockIds for test clarity.
private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
/**
* Test if we get the required number of peers when using random sampling from
* RandomBlockReplicationPolicy
*/
test(s"block replication - random block replication policy") {
val numBlockManagers = 10
val storeSize = 1000
val blockManagers = (1 to numBlockManagers).map { i =>
BlockManagerId(s"store-$i", "localhost", 1000 + i, None)
}
val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None)
val replicationPolicy = new RandomBlockReplicationPolicy
val blockId = "test-block"
(1 to 10).foreach {numReplicas =>
logDebug(s"Num replicas : $numReplicas")
val randomPeers = replicationPolicy.prioritize(
candidateBlockManager,
blockManagers,
mutable.HashSet.empty[BlockManagerId],
blockId,
numReplicas
)
logDebug(s"Random peers : ${randomPeers.mkString(", ")}")
assert(randomPeers.toSet.size === numReplicas)
// choosing n peers out of n
val secondPass = replicationPolicy.prioritize(
candidateBlockManager,
randomPeers,
mutable.HashSet.empty[BlockManagerId],
blockId,
numReplicas
)
logDebug(s"Random peers : ${secondPass.mkString(", ")}")
assert(secondPass.toSet.size === numReplicas)
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import java.io.{File, FileOutputStream}
import org.scalatest.{BeforeAndAfter, Matchers}
import org.apache.spark._
import org.apache.spark.util.Utils
class TopologyMapperSuite extends SparkFunSuite
with Matchers
with BeforeAndAfter
with LocalSparkContext {
test("File based Topology Mapper") {
val numHosts = 100
val numRacks = 4
val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap
val propsFile = createPropertiesFile(props)
val sparkConf = (new SparkConf(false))
sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
val topologyMapper = new FileBasedTopologyMapper(sparkConf)
props.foreach {case (host, topology) =>
val obtainedTopology = topologyMapper.getTopologyForHost(host)
assert(obtainedTopology.isDefined)
assert(obtainedTopology.get === topology)
}
// we get None for hosts not in the file
assert(topologyMapper.getTopologyForHost("host").isEmpty)
cleanup(propsFile)
}
def createPropertiesFile(props: Map[String, String]): File = {
val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
val fileOS = new FileOutputStream(testFile)
props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)}
fileOS.close
testFile
}
def cleanup(testFile: File): Unit = {
testFile.getParentFile.listFiles.filter { file =>
file.getName.startsWith(testFile.getName)
}.foreach { _.delete() }
}
}