Commit 01125a11 authored by Andrew Or, committed by Reynold Xin

Clean up CacheManager et al.

**UPDATE**

I have removed the special handling for `StorageLevel.MEMORY_*_SER` for now, because it introduces a potential performance regression. With the latest changes, this PR consists mainly of style (code readability) fixes. The only functionality change is the update to `MemoryStore#putBytes` to actually return updated blocks; this is a minor bug fix (sketched below).

Now this is mainly a precursor to another PR (once again).
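To make the `putBytes` fix concrete: in the `MemoryStore` hunk further down, `putBytes` now threads the blocks dropped by `tryToPut` into the `PutResult` it returns, so callers can record them. The following is only a minimal sketch of that idea, using simplified stand-in types rather than Spark's real `BlockId`/`BlockStatus`/`PutResult` definitions.

```scala
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

// Simplified stand-ins; the real Spark classes carry more fields.
case class BlockId(name: String)
case class BlockStatus(memSize: Long, diskSize: Long)
case class PutResult(
    size: Long,
    data: Either[Iterator[Any], ByteBuffer],
    droppedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty)

// Before the fix, putBytes built its PutResult without the third argument, so
// any blocks evicted while making room were invisible to the caller. After the
// fix, a caller can fold them into its running list of updated block statuses:
def recordDroppedBlocks(
    result: PutResult,
    updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Unit = {
  updatedBlocks ++= result.droppedBlocks
}
```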

---------
*Old comment*

The deserialized version of a partition may occupy much more space than the serialized version. Therefore, if a partition is to be cached with `StorageLevel.MEMORY_*_SER`, we don't need to fully unroll it into an `ArrayBuffer`, but instead we can unroll it into a potentially much smaller `ByteBuffer`. This may save us from OOMs in this case.
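For context only, since this path was removed before merging: the idea can be sketched as below, using plain Java serialization and illustrative helper names rather than anything from Spark itself.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

// Unrolling for a serialized storage level: stream elements into a byte buffer
// as they are produced, so only one deserialized element is live at a time.
def unrollSerialized(values: Iterator[Any]): ByteBuffer = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  values.foreach(v => out.writeObject(v.asInstanceOf[AnyRef]))
  out.close()
  ByteBuffer.wrap(bytes.toByteArray)
}

// What the merged code keeps doing for in-memory caching: fully unroll into an
// ArrayBuffer, i.e. hold the entire deserialized partition in memory at once,
// which is where the OOM risk comes from.
def unrollDeserialized(values: Iterator[Any]): ArrayBuffer[Any] = {
  val buffer = new ArrayBuffer[Any]
  buffer ++= values
  buffer
}
```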

Author: Andrew Or <andrewor14@gmail.com>

Closes #1083 from andrewor14/unroll-them-partitions and squashes the following commits:

7048aa0 [Andrew Or] Merge branch 'master' of github.com:apache/spark into unroll-them-partitions
3d9a366 [Andrew Or] Minor change for readability
d12b95f [Andrew Or] Remove unused imports (minor)
a4c387b [Andrew Or] Merge branch 'master' of github.com:apache/spark into unroll-them-partitions
cf5f565 [Andrew Or] Remove special handling for MEM_*_SER
0091ec0 [Andrew Or] Address review feedback
44ef282 [Andrew Or] Actually return updated blocks in putBytes
2941c89 [Andrew Or] Clean up BlockStore (minor)
a8f181d [Andrew Or] Add special handling for StorageLevel.MEMORY_*_SER
parent 0ac71d12
@@ -20,25 +20,25 @@ package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, BlockManager, BlockStatus, RDDBlockId, StorageLevel}
import org.apache.spark.storage._
/**
* Spark class responsible for passing RDDs split contents to the BlockManager and making
* Spark class responsible for passing RDDs partition contents to the BlockManager and making
* sure a node doesn't load two copies of an RDD at once.
*/
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
/** Keys of RDD splits that are being computed/loaded. */
/** Keys of RDD partitions that are being computed/loaded. */
private val loading = new HashSet[RDDBlockId]()
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
/** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](
rdd: RDD[T],
split: Partition,
partition: Partition,
context: TaskContext,
storageLevel: StorageLevel): Iterator[T] = {
val key = RDDBlockId(rdd.id, split.index)
val key = RDDBlockId(rdd.id, partition.index)
logDebug(s"Looking for partition $key")
blockManager.get(key) match {
case Some(values) =>
@@ -46,79 +46,28 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo(s"Another thread is loading $key, waiting for it to finish...")
while (loading.contains(key)) {
try {
loading.wait()
} catch {
case e: Exception =>
logWarning(s"Got an exception while waiting for another thread to load $key", e)
}
}
logInfo(s"Finished waiting for $key")
/* See whether someone else has successfully loaded it. The main way this would fail
* is for the RDD-level cache eviction policy if someone else has loaded the same RDD
* partition but we didn't want to make space for it. However, that case is unlikely
* because it's unlikely that two threads would work on the same RDD partition. One
* downside of the current code is that threads wait serially if this does happen. */
blockManager.get(key) match {
case Some(values) =>
return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo(s"Whoever was loading $key failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
// Acquire a lock for loading this partition
// If another thread already holds the lock, wait for it to finish and return its results
val storedValues = acquireLockForPartition[T](key)
if (storedValues.isDefined) {
return new InterruptibleIterator[T](context, storedValues.get)
}
// Otherwise, we have to load the partition ourselves
try {
// If we got here, we have to load the split
logInfo(s"Partition $key not found, computing it")
val computedValues = rdd.computeOrReadCheckpoint(split, context)
val computedValues = rdd.computeOrReadCheckpoint(partition, context)
// Persist the result, so long as the task is not running locally
// If the task is running locally, do not persist the result
if (context.runningLocally) {
return computedValues
}
// Keep track of blocks with updated statuses
var updatedBlocks = Seq[(BlockId, BlockStatus)]()
val returnValue: Iterator[T] = {
if (storageLevel.useDisk && !storageLevel.useMemory) {
/* In the case that this RDD is to be persisted using DISK_ONLY
* the iterator will be passed directly to the blockManager (rather than
* caching it to an ArrayBuffer first), then the resulting block data iterator
* will be passed back to the user. If the iterator generates a lot of data,
* this means that it doesn't all have to be held in memory at one time.
* This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
* blocks aren't dropped by the block store before enabling that. */
updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
blockManager.get(key) match {
case Some(values) =>
values.asInstanceOf[Iterator[T]]
case None =>
logInfo(s"Failure to store $key")
throw new SparkException("Block manager failed to return persisted value")
}
} else {
// In this case the RDD is cached to an array buffer. This will save the results
// if we're dealing with a 'one-time' iterator
val elements = new ArrayBuffer[Any]
elements ++= computedValues
updatedBlocks = blockManager.put(key, elements, storageLevel, tellMaster = true)
elements.iterator.asInstanceOf[Iterator[T]]
}
}
// Update task metrics to include any blocks whose storage status is updated
val metrics = context.taskMetrics
metrics.updatedBlocks = Some(updatedBlocks)
new InterruptibleIterator(context, returnValue)
// Otherwise, cache the values and keep track of any updates in block statuses
val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
val cachedValues = putInBlockManager(key, computedValues, storageLevel, updatedBlocks)
context.taskMetrics.updatedBlocks = Some(updatedBlocks)
new InterruptibleIterator(context, cachedValues)
} finally {
loading.synchronized {
@@ -128,4 +77,76 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
}
}
}
/**
* Acquire a loading lock for the partition identified by the given block ID.
*
* If the lock is free, just acquire it and return None. Otherwise, another thread is already
* loading the partition, so we wait for it to finish and return the values loaded by the thread.
*/
private def acquireLockForPartition[T](id: RDDBlockId): Option[Iterator[T]] = {
loading.synchronized {
if (!loading.contains(id)) {
// If the partition is free, acquire its lock to compute its value
loading.add(id)
None
} else {
// Otherwise, wait for another thread to finish and return its result
logInfo(s"Another thread is loading $id, waiting for it to finish...")
while (loading.contains(id)) {
try {
loading.wait()
} catch {
case e: Exception =>
logWarning(s"Exception while waiting for another thread to load $id", e)
}
}
logInfo(s"Finished waiting for $id")
val values = blockManager.get(id)
if (!values.isDefined) {
/* The block is not guaranteed to exist even after the other thread has finished.
* For instance, the block could be evicted after it was put, but before our get.
* In this case, we still need to load the partition ourselves. */
logInfo(s"Whoever was loading $id failed; we'll try it ourselves")
loading.add(id)
}
values.map(_.asInstanceOf[Iterator[T]])
}
}
}
/**
* Cache the values of a partition, keeping track of any updates in the storage statuses
* of other blocks along the way.
*/
private def putInBlockManager[T](
key: BlockId,
values: Iterator[T],
storageLevel: StorageLevel,
updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Iterator[T] = {
if (!storageLevel.useMemory) {
/* This RDD is not to be cached in memory, so we can just pass the computed values
* as an iterator directly to the BlockManager, rather than first fully unrolling
* it in memory. The latter option potentially uses much more memory and risks OOM
* exceptions that can be avoided. */
updatedBlocks ++= blockManager.put(key, values, storageLevel, tellMaster = true)
blockManager.get(key) match {
case Some(v) => v.asInstanceOf[Iterator[T]]
case None =>
logInfo(s"Failure to store $key")
throw new BlockException(key, s"Block manager failed to return cached value for $key!")
}
} else {
/* This RDD is to be cached in memory. In this case we cannot pass the computed values
* to the BlockManager as an iterator and expect to read it back later. This is because
* we may end up dropping a partition from memory store before getting it back, e.g.
* when the entirety of the RDD does not fit in memory. */
val elements = new ArrayBuffer[Any]
elements ++= values
updatedBlocks ++= blockManager.put(key, elements, storageLevel, tellMaster = true)
elements.iterator.asInstanceOf[Iterator[T]]
}
}
}
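As an aside, the handshake in `acquireLockForPartition` above is a plain wait/notify protocol on the `loading` set. A standalone sketch of the same pattern, with illustrative names that are not Spark's:

```scala
import scala.collection.mutable.HashSet

// Sketch of the loading-lock pattern: a shared HashSet guarded by its own
// monitor. The first thread to register a key does the loading; every other
// thread waits on the set until it is notified.
object LoadingLockSketch {
  private val loading = new HashSet[String]()

  /** Returns true if the caller acquired the lock and must do the loading. */
  def acquireOrWait(key: String): Boolean = loading.synchronized {
    if (!loading.contains(key)) {
      loading.add(key)
      true
    } else {
      while (loading.contains(key)) {
        loading.wait()   // releases the monitor until release() calls notifyAll()
      }
      false
    }
  }

  /** Called by the loading thread once the value is stored (or loading failed). */
  def release(key: String): Unit = loading.synchronized {
    loading.remove(key)
    loading.notifyAll()
  }
}
```

A caller that gets `false` back would then check the block store and, if the block is missing (for example because it was evicted in the meantime), take the lock itself and load the partition, which is exactly the fallback `acquireLockForPartition` implements.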
@@ -17,11 +17,12 @@
package org.apache.spark.scheduler
import scala.language.existentials
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import scala.language.existentials
import org.apache.spark._
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
......
@@ -25,10 +25,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.shuffle.ShuffleWriter
private[spark] object ShuffleMapTask {
@@ -150,7 +147,7 @@ private[spark] class ShuffleMapTask(
for (elem <- rdd.iterator(split, context)) {
writer.write(elem.asInstanceOf[Product2[Any, Any]])
}
return writer.stop(success = true).get
writer.stop(success = true).get
} catch {
case e: Exception =>
if (writer != null) {
......
@@ -24,11 +24,11 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Logging
/**
* Abstract class to store blocks
* Abstract class to store blocks.
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult
private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging {
def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -37,11 +37,17 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
def putValues(
blockId: BlockId,
values: Iterator[Any],
level: StorageLevel,
returnValues: Boolean): PutResult
def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
def putValues(
blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean): PutResult
/**
* Return the size of a block in bytes.
......
@@ -58,11 +58,11 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
val elements = new ArrayBuffer[Any]
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
tryToPut(blockId, elements, sizeEstimate, true)
PutResult(sizeEstimate, Left(values.toIterator))
val putAttempt = tryToPut(blockId, elements, sizeEstimate, deserialized = true)
PutResult(sizeEstimate, Left(values.toIterator), putAttempt.droppedBlocks)
} else {
tryToPut(blockId, bytes, bytes.limit, false)
PutResult(bytes.limit(), Right(bytes.duplicate()))
val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
}
}
......