Commit 60b541ee authored by Tathagata Das, committed by Andrew Or

[SPARK-12004] Preserve the RDD partitioner through RDD checkpointing

The solution is to save the RDD partitioner in a separate file in the RDD checkpoint directory, i.e. `<checkpoint dir>/_partitioner`. In most cases, failing to recover the RDD partitioner does not affect correctness; it only reduces performance. So this solution makes a best-effort attempt to save and recover the partitioner: if either step fails, checkpointing itself is not affected. This makes the patch safe and backward compatible.
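A minimal sketch of the intended behavior (assumes a running `SparkContext` named `sc` and a placeholder checkpoint directory; `checkpointFile` is Spark-internal, so this mirrors the updated `CheckpointSuite` below rather than typical end-user code):

```scala
// Sketch only, not part of the commit. Assumes a running SparkContext `sc`;
// the checkpoint directory path is a placeholder.
import org.apache.spark.HashPartitioner

sc.setCheckpointDir("/tmp/spark-checkpoints")

val pairs = sc.makeRDD(1 to 4).map(_ -> 1).partitionBy(new HashPartitioner(2))
pairs.checkpoint()
pairs.count()   // materializes the checkpoint and writes <checkpoint dir>/_partitioner

// Reload the checkpointed data, as the updated CheckpointSuite does below
// (sc.checkpointFile is only visible inside Spark's own packages).
val recovered = sc.checkpointFile[(Int, Int)](pairs.getCheckpointFile.get)
assert(recovered.partitioner == pairs.partitioner)  // recovered on a best-effort basis
```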

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #9983 from tdas/SPARK-12004.
parent 2cef1cdf
@@ -20,12 +20,12 @@ package org.apache.spark.rdd
import java.io.IOException
import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
@@ -33,8 +33,9 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
*/
private[spark] class ReliableCheckpointRDD[T: ClassTag](
sc: SparkContext,
val checkpointPath: String)
extends CheckpointRDD[T](sc) {
val checkpointPath: String,
_partitioner: Option[Partitioner] = None
) extends CheckpointRDD[T](sc) {
@transient private val hadoopConf = sc.hadoopConfiguration
@transient private val cpath = new Path(checkpointPath)
@@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
/**
* Return the path of the checkpoint directory this RDD reads data from.
*/
override def getCheckpointFile: Option[String] = Some(checkpointPath)
override val getCheckpointFile: Option[String] = Some(checkpointPath)
override val partitioner: Option[Partitioner] = {
_partitioner.orElse {
ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
}
}
/**
* Return partitions described by the files in the checkpoint directory.
@@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging {
"part-%05d".format(partitionIndex)
}
private def checkpointPartitionerFileName(): String = {
"_partitioner"
}
/**
* Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
*/
def writeRDDToCheckpointDirectory[T: ClassTag](
originalRDD: RDD[T],
checkpointDir: String,
blockSize: Int = -1): ReliableCheckpointRDD[T] = {
val sc = originalRDD.sparkContext
// Create the output path for the checkpoint
val checkpointDirPath = new Path(checkpointDir)
val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
if (!fs.mkdirs(checkpointDirPath)) {
throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
}
// Save to file, and reload it as an RDD
val broadcastedConf = sc.broadcast(
new SerializableConfiguration(sc.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
sc.runJob(originalRDD,
writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
if (originalRDD.partitioner.nonEmpty) {
writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
}
val newRDD = new ReliableCheckpointRDD[T](
sc, checkpointDirPath.toString, originalRDD.partitioner)
if (newRDD.partitions.length != originalRDD.partitions.length) {
throw new SparkException(
s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
}
newRDD
}
/**
* Write this partition's values to a checkpoint file.
* Write an RDD partition's data to a checkpoint file.
*/
def writeCheckpointFile[T: ClassTag](
def writePartitionToCheckpointFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
@@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging {
}
}
/**
* Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
* basis; any exception while writing the partitioner is caught, logged and ignored.
*/
private def writePartitionerToCheckpointDir(
sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {
try {
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val serializeStream = serializer.serializeStream(fileOutputStream)
Utils.tryWithSafeFinally {
serializeStream.writeObject(partitioner)
} {
serializeStream.close()
}
logDebug(s"Written partitioner to $partitionerFilePath")
} catch {
case NonFatal(e) =>
logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath")
}
}
/**
* Read a partitioner from the given RDD checkpoint directory, if it exists.
* This is done on a best-effort basis; any exception while reading the partitioner is
* caught, logged and ignored.
*/
private def readCheckpointedPartitionerFile(
sc: SparkContext,
checkpointDirPath: String): Option[Partitioner] = {
try {
val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
if (fs.exists(partitionerFilePath)) {
val fileInputStream = fs.open(partitionerFilePath, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
val partitioner = Utils.tryWithSafeFinally[Partitioner] {
deserializeStream.readObject[Partitioner]
} {
deserializeStream.close()
}
logDebug(s"Read partitioner from $partitionerFilePath")
Some(partitioner)
} else {
logDebug("No partitioner file")
None
}
} catch {
case NonFatal(e) =>
logWarning(s"Error reading partitioner from $checkpointDirPath, " +
s"partitioner will not be recovered which may lead to performance loss", e)
None
}
}
/**
* Read the content of the specified checkpoint file.
*/
@@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
* This is called immediately after the first action invoked on this RDD has completed.
*/
protected override def doCheckpoint(): CheckpointRDD[T] = {
// Create the output path for the checkpoint
val path = new Path(cpDir)
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException(s"Failed to create checkpoint path $cpDir")
}
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableConfiguration(rdd.context.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
}
val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
// Optionally clean our checkpoint files if the reference is out of scope
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
@@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
}
logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
newRDD
}
@@ -21,7 +21,8 @@ import java.io.File
import scala.reflect.ClassTag
import org.apache.spark.CheckpointSuite._
import org.apache.hadoop.fs.Path
import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
@@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
// Test whether the checkpoint file has been created
if (reliableCheckpoint) {
assert(
collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
assert(operatedRDD.getCheckpointFile.nonEmpty)
val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)
assert(collectFunc(recoveredRDD) === result)
assert(recoveredRDD.partitioner === operatedRDD.partitioner)
}
// Test whether dependencies have been changed from its earlier parent RDD
@@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
}
/** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
protected def runTest(name: String)(body: Boolean => Unit): Unit = {
protected def runTest(
name: String,
skipLocalCheckpoint: Boolean = false
)(body: Boolean => Unit): Unit = {
test(name + " [reliable checkpoint]")(body(true))
test(name + " [local checkpoint]")(body(false))
if (!skipLocalCheckpoint) {
test(name + " [local checkpoint]")(body(false))
}
}
/**
@@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
assert(flatMappedRDD.collect() === result)
}
runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>
def testPartitionerCheckpointing(
partitioner: Partitioner,
corruptPartitionerFile: Boolean = false
): Unit = {
val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner)
rddWithPartitioner.checkpoint()
rddWithPartitioner.count()
assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty,
"checkpointing was not successful")
if (corruptPartitionerFile) {
// Overwrite the partitioner file with garbage data
val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get)
val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration)
val partitionerFile = fs.listStatus(checkpointDir)
.find(_.getPath.getName.contains("partitioner"))
.map(_.getPath)
require(partitionerFile.nonEmpty, "could not find the partitioner file for testing")
val output = fs.create(partitionerFile.get, true)
output.write(100)
output.close()
}
val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get)
assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered")
if (!corruptPartitionerFile) {
assert(newRDD.partitioner != None, "partitioner not recovered")
assert(newRDD.partitioner === rddWithPartitioner.partitioner,
"recovered partitioner does not match")
} else {
assert(newRDD.partitioner == None, "partitioner unexpectedly recovered")
}
}
testPartitionerCheckpointing(partitioner)
// Test that corrupted partitioner file does not prevent recovery of RDD
testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
}
runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
testRDD(_.map(x => x.toString), reliableCheckpoint)
testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)