diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 9afb742406ec8352c10beff9f789da29eb3b8412..15fa26e8b527288628248ce2b36162757b105523 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
@@ -96,6 +97,25 @@ class DecisionTreeRegressorSuite
       assert(variance === expectedVariance,
         s"Expected variance $expectedVariance but got $variance.")
     }
+
+    val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+    val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0)
+    dt.setMaxDepth(1)
+      .setMaxBins(6)
+      .setSeed(0)
+    val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
+    val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
+      case Row(variance: Double) => variance
+    }
+
+    // Since max depth is set to 1, the best split separates the labels (1.0, 2.0, 3.0)
+    // from (10.0, 12.0, 14.0). The predicted variance for each data point is the biased
+    // (population) variance of the labels in its leaf: 2/3 ~= 0.667 in the left node
+    // and 8/3 ~= 2.667 in the right node.
+    val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+    calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+      assert(actual ~== expected absTol 1e-3)
+    }
   }
 
   test("Feature importance with toy data") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index d2fa8d0d6335d7d217f7060444ca980b468a75e6..c90cb8ca1034c87f7f1148d530dbeb0fc6006019 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -182,6 +182,18 @@ private[ml] object TreeTests extends SparkFunSuite {
     new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
   ))
 
+  /**
+   * Create toy data for testing the correctness of predicted variances.
+   */
+  def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq(
+    new LabeledPoint(1.0, Vectors.dense(Array(0.0))),
+    new LabeledPoint(2.0, Vectors.dense(Array(1.0))),
+    new LabeledPoint(3.0, Vectors.dense(Array(2.0))),
+    new LabeledPoint(10.0, Vectors.dense(Array(3.0))),
+    new LabeledPoint(12.0, Vectors.dense(Array(4.0))),
+    new LabeledPoint(14.0, Vectors.dense(Array(5.0)))
+  ))
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
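
Note (not part of the patch): a quick standalone check of the expected values 0.667 and 2.667 used in the test above. They are the biased (population) variances of the two label groups that the depth-1 split of varianceData produces; plain Scala, no Spark required:

    // Biased (population) variance: mean squared deviation from the mean.
    def biasedVariance(xs: Seq[Double]): Double = {
      val mean = xs.sum / xs.size
      xs.map(x => (x - mean) * (x - mean)).sum / xs.size
    }

    println(biasedVariance(Seq(1.0, 2.0, 3.0)))     // 0.666... ~= 0.667 (left node labels)
    println(biasedVariance(Seq(10.0, 12.0, 14.0)))  // 2.666... ~= 2.667 (right node labels)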