diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 205d80dd026822e77d2bb2bbedf9d917bb0ddf57..262fd2c9611d077a8c9dbc81be1fbfa1c3f47290 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -272,6 +272,8 @@ object DecisionTreeRunner {
       case Variance => impurity.Variance
     }
 
+    params.checkpointDir.foreach(sc.setCheckpointDir)
+
     val strategy
       = new Strategy(
           algo = params.algo,
@@ -282,7 +284,6 @@ object DecisionTreeRunner {
           minInstancesPerNode = params.minInstancesPerNode,
           minInfoGain = params.minInfoGain,
           useNodeIdCache = params.useNodeIdCache,
-          checkpointDir = params.checkpointDir,
           checkpointInterval = params.checkpointInterval)
     if (params.numTrees == 1) {
       val startTime = System.nanoTime()
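With `checkpointDir` gone from `Strategy`, callers set the directory once on the `SparkContext` before training, exactly as the runner now does via `params.checkpointDir.foreach(sc.setCheckpointDir)`. A minimal sketch of the new calling pattern (the app name, path, and parameter values are illustrative, not part of this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

val sc = new SparkContext(
  new SparkConf().setAppName("DecisionTreeExample").setMaster("local[*]"))

// The checkpoint directory now lives on the SparkContext, not on Strategy.
sc.setCheckpointDir("/tmp/tree-checkpoints")

val strategy = new Strategy(
  algo = Algo.Classification,
  impurity = Gini,
  maxDepth = 5,
  numClasses = 2,
  useNodeIdCache = true,    // enable the node Id cache
  checkpointInterval = 10)  // honored only because a directory is set above
```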
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 45b0154c5e4cb729e71e74b2d76e8c25c229fe13..db01f2e229e5a1727bd698b1d0f6d5113bae52ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -204,7 +204,6 @@ private class RandomForest (
       Some(NodeIdCache.init(
         data = baggedInput,
         numTrees = numTrees,
-        checkpointDir = strategy.checkpointDir,
         checkpointInterval = strategy.checkpointInterval,
         initVal = 1))
     } else {
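For callers of the public trainers nothing else changes; `RandomForest` simply inherits whatever directory the context already has. A hedged usage sketch, reusing `sc` and `strategy` from the sketch above (the data path, tree count, and seed are illustrative):

```scala
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

// LIBSVM-format training data; the path is illustrative.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

val model = RandomForest.trainClassifier(
  input = data,
  strategy = strategy,
  numTrees = 20,
  featureSubsetStrategy = "auto",
  seed = 42)
```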
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 3308adb6752ffd8b14a8af00846fdea47a62d3a5..8d5c36da32bdb1af26f490fe6557e88f64ef8734 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param subsamplingRate Fraction of the training data used for learning decision tree.
  * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
 *                      maintain a separate RDD caching the node Id for each row.
- * @param checkpointDir If the node Id cache is used, it will help to checkpoint
- *                      the node Id cache periodically. This is the checkpoint directory
- *                      to be used for the node Id cache.
  * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
- *                           E.g. 10 means that the cache will get checkpointed every 10 updates.
+ *                           E.g. 10 means that the cache will be checkpointed every 10 updates. If
+ *                           the checkpoint directory is not set in
+ *                           [[org.apache.spark.SparkContext]], this setting is ignored.
  */
 @Experimental
 class Strategy (
@@ -82,7 +81,6 @@ class Strategy (
     @BeanProperty var maxMemoryInMB: Int = 256,
     @BeanProperty var subsamplingRate: Double = 1,
     @BeanProperty var useNodeIdCache: Boolean = false,
-    @BeanProperty var checkpointDir: Option[String] = None,
     @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
 
   def isMulticlassClassification =
@@ -165,7 +163,7 @@ class Strategy (
   def copy: Strategy = {
     new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
-      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
   }
 }
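Because the remaining fields are `@BeanProperty` vars, the same configuration is available through setters (the form Java callers use); a small sketch, with illustrative values:

```scala
import org.apache.spark.mllib.tree.configuration.Strategy

val s = Strategy.defaultStrategy("Classification")
s.setUseNodeIdCache(true)
s.setCheckpointInterval(10)
// Without SparkContext.setCheckpointDir(...), the interval above is ignored.
```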
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index 83011b48b7d9be0692dc08cbed525fd41f7f6d48..bdd0f576b048d47526316431f92a77e3d3dc8d3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
  * The nodeIdsForInstances RDD needs to be updated at each iteration.
  * @param nodeIdsForInstances The initial values in the cache
 *            (should be an Array of all 1's, i.e. every instance starts at the root node).
- * @param checkpointDir The checkpoint directory where
- *                      the checkpointed files will be stored.
  * @param checkpointInterval The checkpointing interval
 *                           (how often the cache should be checkpointed).
  */
 @DeveloperApi
 private[tree] class NodeIdCache(
   var nodeIdsForInstances: RDD[Array[Int]],
-  val checkpointDir: Option[String],
   val checkpointInterval: Int) {
 
   // Keep a reference to a previous node Ids for instances.
@@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
   private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
   private var rddUpdateCount = 0
 
-  // If a checkpoint directory is given, and there's no prior checkpoint directory,
-  // then set the checkpoint directory with the given one.
-  if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
-    nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
-  }
-
   /**
    * Update the node index values in the cache.
    * This updates the RDD and its lineage.
@@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
    * Initialize the node Id cache with initial node Id values.
    * @param data The RDD of training rows.
    * @param numTrees The number of trees that we want to create cache for.
-   * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
    * @param checkpointInterval The checkpointing interval
   *                           (how often the cache should be checkpointed).
    * @param initVal The initial values in the cache.
@@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
   def init(
       data: RDD[BaggedPoint[TreePoint]],
       numTrees: Int,
-      checkpointDir: Option[String],
       checkpointInterval: Int,
       initVal: Int = 1): NodeIdCache = {
     new NodeIdCache(
       data.map(_ => Array.fill[Int](numTrees)(initVal)),
-      checkpointDir,
       checkpointInterval)
   }
 }
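With the constructor argument removed, the cache defers entirely to the context: a periodic checkpoint should fire only when the context already has a directory configured. A hedged sketch of that guard (the helper name `maybeCheckpoint` is hypothetical; other names mirror the class above, and this is not the verbatim Spark source):

```scala
import scala.collection.mutable
import org.apache.spark.rdd.RDD

// Checkpoint the freshly updated cache RDD every `checkpointInterval`
// updates, but only if the SparkContext has a checkpoint directory set.
def maybeCheckpoint(
    nodeIdsForInstances: RDD[Array[Int]],
    checkpointQueue: mutable.Queue[RDD[Array[Int]]],
    rddUpdateCount: Int,
    checkpointInterval: Int): Unit = {
  val sc = nodeIdsForInstances.sparkContext
  if (rddUpdateCount % checkpointInterval == 0 && sc.getCheckpointDir.nonEmpty) {
    nodeIdsForInstances.checkpoint()
    // Track checkpointed RDDs so stale checkpoint files can be removed later.
    checkpointQueue.enqueue(nodeIdsForInstances)
  }
}
```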