Commit f9c5b0a6 authored by Tathagata Das

Changed the checkpoint writing and reading process: checkpoint data is now written per split by CheckpointRDD.writeToFile and read back through the new CheckpointRDD, instead of going through saveAsObjectFile/objectFile.

parent 51841419
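Background for the diff below: an RDD is still checkpointed the same way from user code; this commit only changes what happens internally once checkpointing is triggered. A minimal driver-side sketch, assuming the existing RDD.checkpoint() and a SparkContext whose checkpointDir has been set (setCheckpointDir is assumed here and is not part of this diff):

  val sc = new SparkContext("local", "checkpoint-example")
  sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints")   // assumed setter for rdd.context.checkpointDir
  val data = sc.makeRDD(1 to 100, 4).map(_ * 2)
  data.checkpoint()   // mark for checkpointing; files are written under checkpointDir/rdd-<id>
  data.count()        // running a job triggers the write path shown in RDDCheckpointData below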
spark/RDDCheckpointData.scala

 package spark

 import org.apache.hadoop.fs.Path
-import rdd.CoalescedRDD
+import rdd.{CheckpointRDD, CoalescedRDD}
 import scheduler.{ResultTask, ShuffleMapTask}

 /**
...
@@ -55,30 +55,13 @@ extends Logging with Serializable {
     }

     // Save to file, and reload it as an RDD
-    val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
-    rdd.saveAsObjectFile(file)
-    val newRDD = {
-      val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size)
-      val oldSplits = rdd.splits.size
-      val newSplits = hadoopRDD.splits.size
-      logDebug("RDD splits = " + oldSplits + " --> " + newSplits)
-      if (newSplits < oldSplits) {
-        throw new Exception("# splits after checkpointing is less than before " +
-          "[" + oldSplits + " --> " + newSplits)
-      } else if (newSplits > oldSplits) {
-        new CoalescedRDD(hadoopRDD, rdd.splits.size)
-      } else {
-        hadoopRDD
-      }
-    }
-    logDebug("New RDD has " + newRDD.splits.size + " splits")
+    val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
+    rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+    val newRDD = new CheckpointRDD[T](rdd.context, path)

     // Change the dependencies and splits of the RDD
     RDDCheckpointData.synchronized {
-      cpFile = Some(file)
+      cpFile = Some(path)
       cpRDD = Some(newRDD)
       rdd.changeDependencies(newRDD)
       cpState = Checkpointed
...
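The net effect of the hunk above: instead of saveAsObjectFile followed by objectFile (and a CoalescedRDD to patch up a mismatched split count), the checkpoint data is written by a job that runs CheckpointRDD.writeToFile on every split, and is read back through a CheckpointRDD that creates exactly one split per written part file, so the split count is preserved by construction. The essential round trip, using only names from this commit:

  // write one part-NNNNN file per split, then expose the directory as an RDD again
  val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
  rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
  val newRDD = new CheckpointRDD[T](rdd.context, path)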
New file: spark/rdd/CheckpointRDD.scala

package spark.rdd

import spark._
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, BytesWritable}
import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
  override val index: Int = idx
}

class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
  extends RDD[T](sc, Nil) {

  @transient val path = new Path(checkpointPath)
  @transient val fs = path.getFileSystem(new Configuration())

  @transient val splits_ : Array[Split] = {
    val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
    splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
  }

  override def getSplits = splits_

  override def getPreferredLocations(split: Split): Seq[String] = {
    val status = fs.getFileStatus(path)
    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
    locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
  }

  override def compute(split: Split): Iterator[T] = {
    CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile)
  }
  override def checkpoint() {
    // Do nothing. A CheckpointRDD is already materialized and should not be checkpointed again.
  }
}
private[spark] object CheckpointRDD extends Logging {

  def splitIdToFileName(splitId: Int): String = {
    val numfmt = NumberFormat.getInstance()
    numfmt.setMinimumIntegerDigits(5)
    numfmt.setGroupingUsed(false)
    "part-" + numfmt.format(splitId)
  }

  def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
    val outputDir = new Path(path)
    val fs = outputDir.getFileSystem(new Configuration())

    val finalOutputName = splitIdToFileName(context.splitId)
    val finalOutputPath = new Path(outputDir, finalOutputName)
    val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)

    if (fs.exists(tempOutputPath)) {
      throw new IOException("Checkpoint failed: temporary path " +
        tempOutputPath + " already exists")
    }

    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
    val fileOutputStream = if (blockSize < 0) {
      fs.create(tempOutputPath, false, bufferSize)
    } else {
      // This is mainly for testing purposes
      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
    }
    val serializer = SparkEnv.get.serializer.newInstance()
    val serializeStream = serializer.serializeStream(fileOutputStream)
    serializeStream.writeAll(iterator)
    fileOutputStream.close()

    if (!fs.rename(tempOutputPath, finalOutputPath)) {
      if (!fs.delete(finalOutputPath, true)) {
        throw new IOException("Checkpoint failed: failed to delete earlier output of task "
          + context.attemptId)
      }
      if (!fs.rename(tempOutputPath, finalOutputPath)) {
        throw new IOException("Checkpoint failed: failed to save output of task: "
          + context.attemptId)
      }
    }
  }

  def readFromFile[T](path: String): Iterator[T] = {
    val inputPath = new Path(path)
    val fs = inputPath.getFileSystem(new Configuration())
    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
    val fileInputStream = fs.open(inputPath, bufferSize)
    val serializer = SparkEnv.get.serializer.newInstance()
    val deserializeStream = serializer.deserializeStream(fileInputStream)
    deserializeStream.asIterator.asInstanceOf[Iterator[T]]
  }
  // Test whether CheckpointRDD generates the expected number of splits despite
  // each split file having multiple blocks. This needs to be run on a
  // cluster (Mesos or standalone) using HDFS.
  def main(args: Array[String]) {
    import spark._

    val Array(cluster, hdfsPath) = args
    val sc = new SparkContext(cluster, "CheckpointRDD Test")
    val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
    val path = new Path(hdfsPath, "temp")
    val fs = path.getFileSystem(new Configuration())
    sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _)
    val cpRDD = new CheckpointRDD[Int](sc, path.toString)
    assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
    assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
    fs.delete(path)
  }
}
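As a side note, the helpers above can also be used to read a single split's file back directly; this sketch assumes a SparkEnv is already initialized on the calling thread (readFromFile obtains its serializer from SparkEnv.get) and uses a hypothetical checkpoint directory for RDD 12:

  // hypothetical paths; reads split 3 of a checkpointed RDD of Ints
  val dir = "hdfs:///tmp/spark-checkpoints/rdd-12"
  val splitFile = new Path(dir, CheckpointRDD.splitIdToFileName(3)).toString
  val elems: Iterator[Int] = CheckpointRDD.readFromFile[Int](splitFile)
  elems.foreach(println)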
spark/rdd/HadoopRDD.scala

@@ -25,8 +25,7 @@ import spark.Split
  * A Spark split class that wraps around a Hadoop InputSplit.
  */
 private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
-  extends Split
-  with Serializable {
+  extends Split {
   val inputSplit = new SerializableWritable[InputSplit](s)
...
@@ -117,6 +116,6 @@ class HadoopRDD[K, V](
   }
   override def checkpoint() {
-    // Do nothing. Hadoop RDD cannot be checkpointed.
+    // Do nothing. Hadoop RDD should not be checkpointed.
   }
 }