Skip to content
Snippets Groups Projects
Commit 77bcd77e authored by Aaditya Ramesh's avatar Aaditya Ramesh Committed by Shixiong Zhu
Browse files

[SPARK-19525][CORE] Add RDD checkpoint compression support

## What changes were proposed in this pull request?

This PR adds RDD checkpoint compression support and add a new config `spark.checkpoint.compress` to enable/disable it. Credit goes to aramesh117

Closes #17024

## How was this patch tested?

The new unit test.

Author: Shixiong Zhu <shixiong@databricks.com>
Author: Aaditya Ramesh <aramesh@conviva.com>

Closes #17789 from zsxwing/pr17024.
parent ebff519c
No related branches found
No related tags found
No related merge requests found
......@@ -272,4 +272,10 @@ package object config {
.booleanConf
.createWithDefault(false)
private[spark] val CHECKPOINT_COMPRESS =
ConfigBuilder("spark.checkpoint.compress")
.doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
"spark.io.compression.codec.")
.booleanConf
.createWithDefault(false)
}
......@@ -18,6 +18,7 @@
package org.apache.spark.rdd
import java.io.{FileNotFoundException, IOException}
import java.util.concurrent.TimeUnit
import scala.reflect.ClassTag
import scala.util.control.NonFatal
......@@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.CHECKPOINT_COMPRESS
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
......@@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
originalRDD: RDD[T],
checkpointDir: String,
blockSize: Int = -1): ReliableCheckpointRDD[T] = {
val checkpointStartTimeNs = System.nanoTime()
val sc = originalRDD.sparkContext
......@@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging {
writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
}
val checkpointDurationMs =
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
logInfo(s"Checkpointing took $checkpointDurationMs ms.")
val newRDD = new ReliableCheckpointRDD[T](
sc, checkpointDirPath.toString, originalRDD.partitioner)
if (newRDD.partitions.length != originalRDD.partitions.length) {
......@@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileOutputStream = if (blockSize < 0) {
fs.create(tempOutputPath, false, bufferSize)
val fileStream = fs.create(tempOutputPath, false, bufferSize)
if (env.conf.get(CHECKPOINT_COMPRESS)) {
CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream)
} else {
fileStream
}
} else {
// This is mainly for testing purpose
fs.create(tempOutputPath, false, bufferSize,
......@@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val env = SparkEnv.get
val fs = path.getFileSystem(broadcastedConf.value.value)
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileInputStream = fs.open(path, bufferSize)
val fileInputStream = {
val fileStream = fs.open(path, bufferSize)
if (env.conf.get(CHECKPOINT_COMPRESS)) {
CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream)
} else {
fileStream
}
}
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
......
......@@ -21,8 +21,10 @@ import java.io.File
import scala.reflect.ClassTag
import com.google.common.io.ByteStreams
import org.apache.hadoop.fs.Path
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
......@@ -580,3 +582,42 @@ object CheckpointSuite {
).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
}
}
class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
test("checkpoint compression") {
val checkpointDir = Utils.createTempDir()
try {
val conf = new SparkConf()
.set("spark.checkpoint.compress", "true")
.set("spark.ui.enabled", "false")
sc = new SparkContext("local", "test", conf)
sc.setCheckpointDir(checkpointDir.toString)
val rdd = sc.makeRDD(1 to 20, numSlices = 1)
rdd.checkpoint()
assert(rdd.collect().toSeq === (1 to 20))
// Verify that RDD is checkpointed
assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
val checkpointPath = new Path(rdd.getCheckpointFile.get)
val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
val checkpointFile =
fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get
// Verify the checkpoint file is compressed, in other words, can be decompressed
val compressedInputStream = CompressionCodec.createCodec(conf)
.compressedInputStream(fs.open(checkpointFile))
try {
ByteStreams.toByteArray(compressedInputStream)
} finally {
compressedInputStream.close()
}
// Verify that the compressed content can be read back
assert(rdd.collect().toSeq === (1 to 20))
} finally {
Utils.deleteRecursively(checkpointDir)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment