From 77bcd77ed5fbd91fe61849cca76a8dffe5e4d6b2 Mon Sep 17 00:00:00 2001
From: Aaditya Ramesh <aramesh@conviva.com>
Date: Fri, 28 Apr 2017 15:28:56 -0700
Subject: [PATCH] [SPARK-19525][CORE] Add RDD checkpoint compression support

## What changes were proposed in this pull request?

This PR adds RDD checkpoint compression support and add a new config `spark.checkpoint.compress` to enable/disable it. Credit goes to aramesh117

Closes #17024

## How was this patch tested?

The new unit test.

Author: Shixiong Zhu <shixiong@databricks.com>
Author: Aaditya Ramesh <aramesh@conviva.com>

Closes #17789 from zsxwing/pr17024.
---
 .../spark/internal/config/package.scala       |  6 +++
 .../spark/rdd/ReliableCheckpointRDD.scala     | 24 ++++++++++-
 .../org/apache/spark/CheckpointSuite.scala    | 41 +++++++++++++++++++
 3 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 2f0a3064be..7f7921d56f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -272,4 +272,10 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val CHECKPOINT_COMPRESS =
+    ConfigBuilder("spark.checkpoint.compress")
+      .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
+        "spark.io.compression.codec.")
+      .booleanConf
+      .createWithDefault(false)
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index e0a29b4831..37c67cee55 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 
 import java.io.{FileNotFoundException, IOException}
+import java.util.concurrent.TimeUnit
 
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
@@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CHECKPOINT_COMPRESS
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
@@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
       originalRDD: RDD[T],
       checkpointDir: String,
       blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+    val checkpointStartTimeNs = System.nanoTime()
 
     val sc = originalRDD.sparkContext
 
@@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging {
       writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
     }
 
+    val checkpointDurationMs =
+      TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
+    logInfo(s"Checkpointing took $checkpointDurationMs ms.")
+
     val newRDD = new ReliableCheckpointRDD[T](
       sc, checkpointDirPath.toString, originalRDD.partitioner)
     if (newRDD.partitions.length != originalRDD.partitions.length) {
@@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
 
     val fileOutputStream = if (blockSize < 0) {
-      fs.create(tempOutputPath, false, bufferSize)
+      val fileStream = fs.create(tempOutputPath, false, bufferSize)
+      if (env.conf.get(CHECKPOINT_COMPRESS)) {
+        CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream)
+      } else {
+        fileStream
+      }
     } else {
       // This is mainly for testing purpose
       fs.create(tempOutputPath, false, bufferSize,
@@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     val env = SparkEnv.get
     val fs = path.getFileSystem(broadcastedConf.value.value)
     val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
-    val fileInputStream = fs.open(path, bufferSize)
+    val fileInputStream = {
+      val fileStream = fs.open(path, bufferSize)
+      if (env.conf.get(CHECKPOINT_COMPRESS)) {
+        CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream)
+      } else {
+        fileStream
+      }
+    }
     val serializer = env.serializer.newInstance()
     val deserializeStream = serializer.deserializeStream(fileInputStream)
 
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index b117c7709b..ee70a3399e 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,8 +21,10 @@ import java.io.File
 
 import scala.reflect.ClassTag
 
+import com.google.common.io.ByteStreams
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd._
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 import org.apache.spark.util.Utils
@@ -580,3 +582,42 @@ object CheckpointSuite {
     ).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
   }
 }
+
+class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("checkpoint compression") {
+    val checkpointDir = Utils.createTempDir()
+    try {
+      val conf = new SparkConf()
+        .set("spark.checkpoint.compress", "true")
+        .set("spark.ui.enabled", "false")
+      sc = new SparkContext("local", "test", conf)
+      sc.setCheckpointDir(checkpointDir.toString)
+      val rdd = sc.makeRDD(1 to 20, numSlices = 1)
+      rdd.checkpoint()
+      assert(rdd.collect().toSeq === (1 to 20))
+
+      // Verify that RDD is checkpointed
+      assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
+
+      val checkpointPath = new Path(rdd.getCheckpointFile.get)
+      val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
+      val checkpointFile =
+        fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get
+
+      // Verify the checkpoint file is compressed, in other words, can be decompressed
+      val compressedInputStream = CompressionCodec.createCodec(conf)
+        .compressedInputStream(fs.open(checkpointFile))
+      try {
+        ByteStreams.toByteArray(compressedInputStream)
+      } finally {
+        compressedInputStream.close()
+      }
+
+      // Verify that the compressed content can be read back
+      assert(rdd.collect().toSeq === (1 to 20))
+    } finally {
+      Utils.deleteRecursively(checkpointDir)
+    }
+  }
+}
-- 
GitLab