Commit 25998e4d authored by GuoQiang Li, committed by Andrew Or

[SPARK-2033] Automatically cleanup checkpoint

Author: GuoQiang Li <witgo@qq.com>

Closes #855 from witgo/cleanup_checkpoint_date and squashes the following commits:

1649850 [GuoQiang Li] review commit
c0087e0 [GuoQiang Li] Automatically cleanup checkpoint
parent dcf8a9f3
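
As a quick orientation before the diff: with this change, the files a checkpointed RDD writes under the checkpoint directory can be deleted automatically once the RDD itself is garbage collected, but only when spark.cleaner.referenceTracking.cleanCheckpoints is set to true (it defaults to false). Below is a minimal usage sketch in the spirit of the test added in this commit; the object name, checkpoint path, and sample data are illustrative and not part of the commit.

import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: mirrors the new ContextCleanerSuite test, not an official example.
object CheckpointCleanupSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("checkpoint-cleanup-sketch")
      // Opt in to automatic checkpoint cleanup (defaults to false).
      .set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/checkpoint-cleanup-sketch") // illustrative path

    var rdd = sc.parallelize(1 to 100).map(i => (i % 10, i))
    rdd.checkpoint()
    rdd.collect() // materializes the RDD and writes checkpoint files under rdd-<id>/

    // Once no strong references to the RDD remain, a later GC cycle lets
    // ContextCleaner remove the rdd-<id> checkpoint directory via
    // CleanCheckpoint / doCleanCheckpoint.
    rdd = null
    System.gc()

    sc.stop()
  }
}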
core/src/main/scala/org/apache/spark/ContextCleaner.scala

@@ -22,7 +22,7 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDDCheckpointData, RDD}
 import org.apache.spark.util.Utils

 /**
@@ -33,6 +33,7 @@ private case class CleanRDD(rddId: Int) extends CleanupTask
 private case class CleanShuffle(shuffleId: Int) extends CleanupTask
 private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
 private case class CleanAccum(accId: Long) extends CleanupTask
+private case class CleanCheckpoint(rddId: Int) extends CleanupTask

 /**
  * A WeakReference associated with a CleanupTask.
@@ -94,12 +95,12 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   @volatile private var stopped = false

   /** Attach a listener object to get information of when objects are cleaned. */
-  def attachListener(listener: CleanerListener) {
+  def attachListener(listener: CleanerListener): Unit = {
     listeners += listener
   }

   /** Start the cleaner. */
-  def start() {
+  def start(): Unit = {
     cleaningThread.setDaemon(true)
     cleaningThread.setName("Spark Context Cleaner")
     cleaningThread.start()
@@ -108,7 +109,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   /**
    * Stop the cleaning thread and wait until the thread has finished running its current task.
    */
-  def stop() {
+  def stop(): Unit = {
     stopped = true
     // Interrupt the cleaning thread, but wait until the current task has finished before
     // doing so. This guards against the race condition where a cleaning thread may
@@ -121,7 +122,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Register a RDD for cleanup when it is garbage collected. */
-  def registerRDDForCleanup(rdd: RDD[_]) {
+  def registerRDDForCleanup(rdd: RDD[_]): Unit = {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }

@@ -130,17 +131,22 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Register a ShuffleDependency for cleanup when it is garbage collected. */
-  def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
+  def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]): Unit = {
     registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
   }

   /** Register a Broadcast for cleanup when it is garbage collected. */
-  def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+  def registerBroadcastForCleanup[T](broadcast: Broadcast[T]): Unit = {
     registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
   }

+  /** Register a RDDCheckpointData for cleanup when it is garbage collected. */
+  def registerRDDCheckpointDataForCleanup[T](rdd: RDD[_], parentId: Int): Unit = {
+    registerForCleanup(rdd, CleanCheckpoint(parentId))
+  }
+
   /** Register an object for cleanup. */
-  private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
+  private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
     referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
   }

@@ -164,6 +170,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
                 doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
               case CleanAccum(accId) =>
                 doCleanupAccum(accId, blocking = blockOnCleanupTasks)
+              case CleanCheckpoint(rddId) =>
+                doCleanCheckpoint(rddId)
             }
           }
         }
@@ -175,7 +183,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Perform RDD cleanup. */
-  def doCleanupRDD(rddId: Int, blocking: Boolean) {
+  def doCleanupRDD(rddId: Int, blocking: Boolean): Unit = {
     try {
       logDebug("Cleaning RDD " + rddId)
       sc.unpersistRDD(rddId, blocking)
@@ -187,7 +195,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Perform shuffle cleanup, asynchronously. */
-  def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
+  def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = {
     try {
       logDebug("Cleaning shuffle " + shuffleId)
       mapOutputTrackerMaster.unregisterShuffle(shuffleId)
@@ -200,7 +208,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Perform broadcast cleanup. */
-  def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
+  def doCleanupBroadcast(broadcastId: Long, blocking: Boolean): Unit = {
     try {
       logDebug(s"Cleaning broadcast $broadcastId")
       broadcastManager.unbroadcast(broadcastId, true, blocking)
@@ -212,7 +220,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   }

   /** Perform accumulator cleanup. */
-  def doCleanupAccum(accId: Long, blocking: Boolean) {
+  def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
     try {
       logDebug("Cleaning accumulator " + accId)
       Accumulators.remove(accId)
@@ -223,6 +231,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     }
   }

+  /** Perform checkpoint cleanup. */
+  def doCleanCheckpoint(rddId: Int): Unit = {
+    try {
+      logDebug("Cleaning rdd checkpoint data " + rddId)
+      RDDCheckpointData.clearRDDCheckpointData(sc, rddId)
+      logInfo("Cleaned rdd checkpoint data " + rddId)
+    }
+    catch {
+      case e: Exception => logError("Error cleaning rdd checkpoint data " + rddId, e)
+    }
+  }
+
   private def blockManagerMaster = sc.env.blockManager.master
   private def broadcastManager = sc.env.broadcastManager
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
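
The new CleanCheckpoint task plugs into the cleaner's existing machinery: registerForCleanup wraps the tracked object in a CleanupTaskWeakReference bound to the cleaner's ReferenceQueue, and the cleaning thread runs the matching doClean* method once the referent has been garbage collected. For readers unfamiliar with that mechanism, here is a minimal, self-contained sketch of the weak-reference pattern; only the JDK classes are real, the other names are illustrative, and this is not Spark's actual implementation.

import java.lang.ref.{ReferenceQueue, WeakReference}

// A weak reference that also carries the cleanup action to run once its
// referent has been garbage collected (same idea as CleanupTaskWeakReference).
class TaskWeakReference[T <: AnyRef](val task: () => Unit, referent: T, queue: ReferenceQueue[T])
  extends WeakReference[T](referent, queue)

object WeakCleanupSketch {
  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[AnyRef]()
    var tracked: AnyRef = new Array[Byte](8 * 1024 * 1024)
    // Keep a strong reference to the TaskWeakReference itself, as ContextCleaner
    // does with its referenceBuffer, so it is not collected before it fires.
    val ref = new TaskWeakReference(() => println("running cleanup task"), tracked, queue)

    tracked = null // drop the only strong reference to the tracked object
    System.gc()    // ask the JVM to collect it (best effort)

    // The cleaning loop: poll the queue and run the task attached to whatever
    // reference was enqueued after its referent got collected.
    Option(queue.remove(1000)) match {
      case Some(r: TaskWeakReference[_]) => r.task()
      case _ => println("nothing collected within the timeout")
    }
  }
}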
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala

@@ -21,7 +21,7 @@ import scala.reflect.ClassTag

 import org.apache.hadoop.fs.Path

-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkException}
+import org.apache.spark._
 import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}

 /**
@@ -83,7 +83,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     }

     // Create the output path for the checkpoint
-    val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
+    val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
     val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
     if (!fs.mkdirs(path)) {
       throw new SparkException("Failed to create checkpoint path " + path)
@@ -92,8 +92,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     // Save to file, and reload it as an RDD
     val broadcastedConf = rdd.context.broadcast(
       new SerializableWritable(rdd.context.hadoopConfiguration))
-    rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
     val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
+    if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
+      rdd.context.cleaner.foreach { cleaner =>
+        cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
+      }
+    }
+    rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
     if (newRDD.partitions.length != rdd.partitions.length) {
       throw new SparkException(
         "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
@@ -130,5 +135,17 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   }
 }

-// Used for synchronization
-private[spark] object RDDCheckpointData
+private[spark] object RDDCheckpointData {
+  def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
+    sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) }
+  }
+
+  def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
+    rddCheckpointDataPath(sc, rddId).foreach { path =>
+      val fs = path.getFileSystem(sc.hadoopConfiguration)
+      if (fs.exists(path)) {
+        fs.delete(path, true)
+      }
+    }
+  }
+}
core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala

@@ -28,7 +28,8 @@ import org.scalatest.concurrent.{PatienceConfiguration, Eventually}
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._

-import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.{RDDCheckpointData, RDD}
 import org.apache.spark.storage._
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -205,6 +206,52 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase {
     postGCTester.assertCleanup()
   }

+  test("automatically cleanup checkpoint") {
+    val checkpointDir = java.io.File.createTempFile("temp", "")
+    checkpointDir.deleteOnExit()
+    checkpointDir.delete()
+    var rdd = newPairRDD
+    sc.setCheckpointDir(checkpointDir.toString)
+    rdd.checkpoint()
+    rdd.cache()
+    rdd.collect()
+    var rddId = rdd.id
+
+    // Confirm the checkpoint directory exists
+    assert(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).isDefined)
+    val path = RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get
+    val fs = path.getFileSystem(sc.hadoopConfiguration)
+    assert(fs.exists(path))
+
+    // the checkpoint is not cleaned by default (without the configuration set)
+    var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil)
+    rdd = null // Make RDD out of scope
+    runGC()
+    postGCTester.assertCleanup()
+    assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+
+    sc.stop()
+    val conf = new SparkConf().setMaster("local[2]").setAppName("cleanupCheckpoint").
+      set("spark.cleaner.referenceTracking.cleanCheckpoints", "true")
+    sc = new SparkContext(conf)
+    rdd = newPairRDD
+    sc.setCheckpointDir(checkpointDir.toString)
+    rdd.checkpoint()
+    rdd.cache()
+    rdd.collect()
+    rddId = rdd.id
+
+    // Confirm the checkpoint directory exists
+    assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+
+    // Test that GC causes checkpoint data cleanup after dereferencing the RDD
+    postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil)
+    rdd = null // Make RDD out of scope
+    runGC()
+    postGCTester.assertCleanup()
+    assert(!fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get))
+  }
+
   test("automatically cleanup RDD + shuffle + broadcast") {
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly