diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 03eeaa707715bd266961fad8ae95d8a0db7447cf..6737a2f4176c2341ab7299fdf33a0a78770a41d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -909,32 +910,39 @@
         // Iterate over all features.
         var featureIndex = 0
         while (featureIndex < numFeatures) {
-          val numSplits = metadata.numSplits(featureIndex)
-          val numBins = metadata.numBins(featureIndex)
           if (metadata.isContinuous(featureIndex)) {
-            val numSamples = sampledInput.length
+            val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
+            val featureSplits = findSplitsForContinuousFeature(featureSamples,
+              metadata, featureIndex)
+
+            val numSplits = featureSplits.length
+            val numBins = numSplits + 1
+            logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
             splits(featureIndex) = new Array[Split](numSplits)
             bins(featureIndex) = new Array[Bin](numBins)
-            val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-            val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
-            logDebug("stride = " + stride)
-            for (splitIndex <- 0 until numSplits) {
-              val sampleIndex = splitIndex * stride.toInt
-              // Set threshold halfway in between 2 samples.
-              val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+
+            var splitIndex = 0
+            while (splitIndex < numSplits) {
+              val threshold = featureSplits(splitIndex)
               splits(featureIndex)(splitIndex) =
                 new Split(featureIndex, threshold, Continuous, List())
+              splitIndex += 1
             }
             bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
               splits(featureIndex)(0), Continuous, Double.MinValue)
-            for (splitIndex <- 1 until numSplits) {
+
+            splitIndex = 1
+            while (splitIndex < numSplits) {
               bins(featureIndex)(splitIndex) =
                 new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
                   Continuous, Double.MinValue)
+              splitIndex += 1
             }
             bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
               new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
           } else {
+            val numSplits = metadata.numSplits(featureIndex)
+            val numBins = metadata.numBins(featureIndex)
             // Categorical feature
             val featureArity = metadata.featureArity(featureIndex)
             if (metadata.isUnordered(featureIndex)) {
@@ -1011,4 +1019,77 @@
     categories
   }
 
+  /**
+   * Find splits for a continuous feature.
+   * NOTE: The number of splits returned is determined by `featureSamples` and
+   *       may be smaller than the specified `numSplits`.
+   *       The `numSplits` value in `DecisionTreeMetadata` is updated accordingly.
+   * @param featureSamples feature values of each sample
+   * @param metadata decision tree metadata
+   *                 NOTE: `metadata.numBins` will be updated accordingly
+   *                       if there are not enough distinct values to create the requested splits
+   * @param featureIndex index of the feature for which to find splits
+   * @return array of split thresholds
+   */
+  private[tree] def findSplitsForContinuousFeature(
+      featureSamples: Array[Double],
+      metadata: DecisionTreeMetadata,
+      featureIndex: Int): Array[Double] = {
+    require(metadata.isContinuous(featureIndex),
+      "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+    val splits = {
+      val numSplits = metadata.numSplits(featureIndex)
+
+      // get count for each distinct value
+      val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+        m + ((x, m.getOrElse(x, 0) + 1))
+      }
+      // sort distinct values
+      val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+      // If there are not more possible splits than requested, return every distinct value as a split.
+      val possibleSplits = valueCounts.length
+      if (possibleSplits <= numSplits) {
+        valueCounts.map(_._1)
+      } else {
+        // stride between splits
+        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+        logDebug("stride = " + stride)
+
+        // iterate over `valueCounts` to find splits
+        val splits = new ArrayBuffer[Double]
+        var index = 1
+        // currentCount: cumulative count of the values visited so far
+        var currentCount = valueCounts(0)._2
+        // targetCount: target value for `currentCount`.
+        // When `currentCount` is closest to `targetCount`,
+        // the current value is a split threshold.
+        // After a split threshold is found, `targetCount` is increased by `stride`.
+        var targetCount = stride
+        while (index < valueCounts.length) {
+          val previousCount = currentCount
+          currentCount += valueCounts(index)._2
+          val previousGap = math.abs(previousCount - targetCount)
+          val currentGap = math.abs(currentCount - targetCount)
+          // If adding the count of the current value to `currentCount`
+          // widens the gap between `currentCount` and `targetCount`,
+          // the previous value is a split threshold.
+          if (previousGap < currentGap) {
+            splits.append(valueCounts(index - 1)._1)
+            targetCount += stride
+          }
+          index += 1
+        }
+
+        splits.toArray
+      }
+    }
+
+    assert(splits.length > 0)
+    // set number of splits accordingly
+    metadata.setNumSplits(featureIndex, splits.length)
+
+    splits
+  }
 }
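
For readers following the hunk above, here is a minimal, self-contained sketch of
the distinct-value counting and approximate-quantile selection that
findSplitsForContinuousFeature performs. The object name SplitSketch and the demo
input are illustrative only, not part of the patch:

import scala.collection.mutable.ArrayBuffer

object SplitSketch {
  // Mirror of the split-selection loop above: pick thresholds from the
  // sorted distinct values so that the cumulative sample count between
  // consecutive thresholds is roughly featureSamples.length / (numSplits + 1).
  def chooseSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
    // Count occurrences of each distinct value, then sort by value.
    val valueCounts = featureSamples
      .groupBy(identity)
      .map { case (v, vs) => (v, vs.length) }
      .toArray
      .sortBy(_._1)

    if (valueCounts.length <= numSplits) {
      // Not enough distinct values: every distinct value becomes a threshold.
      valueCounts.map(_._1)
    } else {
      val stride = featureSamples.length.toDouble / (numSplits + 1)
      val splits = new ArrayBuffer[Double]
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      var index = 1
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // The previous value is a threshold when the cumulative count
        // has just moved past the current target.
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          splits += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      splits.toArray
    }
  }

  def main(args: Array[String]): Unit = {
    // Skewed sample from the "most samples close to the minimum" test below.
    val samples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
    println(chooseSplits(samples, 2).mkString(", ")) // prints: 2.0, 3.0
  }
}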
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 772c02670e5416885f0d91305bf031757ab2dbf0..5bc0f2635c6b198f03853c45e47ca7132ae8d466 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -76,6 +76,16 @@ private[tree] class DecisionTreeMetadata(
     numBins(featureIndex) - 1
   }
 
+  /**
+   * Set the number of splits for a continuous feature.
+   * For a continuous feature, the number of bins is the number of splits plus 1.
+   */
+  def setNumSplits(featureIndex: Int, numSplits: Int) {
+    require(isContinuous(featureIndex),
+      "The number of splits can only be set for a continuous feature.")
+    numBins(featureIndex) = numSplits + 1
+  }
+
   /**
    * Indicates if feature subsampling is being used.
    */
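
The invariant behind setNumSplits is that, for a continuous feature, k thresholds
always partition the feature range into k + 1 bins. A toy illustration under that
assumption (BinIndexSketch and binIndex are hypothetical names, not Spark APIs),
using the same value <= threshold convention as Spark's continuous Split:

object BinIndexSketch {
  // A value lands in the bin indexed by how many thresholds it exceeds,
  // so k thresholds yield bin indices 0..k, i.e. numBins = numSplits + 1.
  def binIndex(x: Double, sortedThresholds: Array[Double]): Int =
    sortedThresholds.count(t => x > t)

  def main(args: Array[String]): Unit = {
    val thresholds = Array(1.0, 3.0) // 2 thresholds => 3 bins
    assert(binIndex(0.5, thresholds) == 0) // below both thresholds
    assert(binIndex(2.0, thresholds) == 1) // between the thresholds
    assert(binIndex(4.0, thresholds) == 2) // above both thresholds
  }
}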
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 98a72b0c4d750f7f504ea2dbeeddf5a61dee256d..8fc5e111bbc1788f88f067ea0ed831e134f6612f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
@@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
   }
 
+  test("find splits for a continuous feature") {
+    // find splits for the normal case
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(6), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array.fill(200000)(math.random)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 5)
+      assert(fakeMetadata.numSplits(0) === 5)
+      assert(fakeMetadata.numBins(0) === 6)
+      // check returned splits are distinct
+      assert(splits.distinct.length === splits.length)
+    }
+
+    // split finding should not return identical splits;
+    // when there are not enough split candidates, reduce the number of splits in metadata
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(5), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 3)
+      assert(fakeMetadata.numSplits(0) === 3)
+      assert(fakeMetadata.numBins(0) === 4)
+      // check returned splits are distinct
+      assert(splits.distinct.length === splits.length)
+    }
+
+    // find splits when most samples are close to the minimum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 2)
+      assert(fakeMetadata.numSplits(0) === 2)
+      assert(fakeMetadata.numBins(0) === 3)
+      assert(splits(0) === 2.0)
+      assert(splits(1) === 3.0)
+    }
+
+    // find splits when most samples are close to the maximum
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
+      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+      val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+      assert(splits.length === 1)
+      assert(fakeMetadata.numSplits(0) === 1)
+      assert(fakeMetadata.numBins(0) === 2)
+      assert(splits(0) === 1.0)
+    }
+  }
+
   test("Multiclass classification with unordered categorical features:" +
       " split and bin calculations") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index fb44ceb0f57ee6f1ef17ad3142cbbd130fe8bab6..6b13765b98f41caf188dae9c7c687cd344bfef98 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -93,8 +93,9 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     val categoricalFeaturesInfo = Map.empty[Int, Int]
     val numTrees = 1
 
-    val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    val strategy = new Strategy(algo = Regression, impurity = Variance,
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo)
 
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
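
As a usage note, the Strategy and RandomForest calls exercised above compose as
follows; a sketch that assumes an RDD[LabeledPoint] named rdd is already in scope
(e.g. in spark-shell):

import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Variance

// maxBins caps how many bins are used to discretize continuous features;
// with the adaptive split finding above, fewer bins may actually be used.
val strategy = new Strategy(algo = Regression, impurity = Variance,
  maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
  categoricalFeaturesInfo = Map.empty[Int, Int])
val model = RandomForest.trainRegressor(rdd, strategy, numTrees = 1,
  featureSubsetStrategy = "auto", seed = 123)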
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 0938eebd3a548aa9d31fa7be37edc33163772417..64ee79d83e849a1f46200dc6db6c8ade0307fd7e 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -153,9 +153,9 @@ class DecisionTree(object):
         DecisionTreeModel classifier of depth 1 with 3 nodes
         >>> print model.toDebugString(),  # it already has newline
         DecisionTreeModel classifier of depth 1 with 3 nodes
-          If (feature 0 <= 0.5)
+          If (feature 0 <= 0.0)
            Predict: 0.0
-          Else (feature 0 > 0.5)
+          Else (feature 0 > 0.0)
            Predict: 1.0
         >>> model.predict(array([1.0])) > 0
         True