diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 9afb742406ec8352c10beff9f789da29eb3b8412..15fa26e8b527288628248ce2b36162757b105523 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} @@ -96,6 +97,25 @@ class DecisionTreeRegressorSuite assert(variance === expectedVariance, s"Expected variance $expectedVariance but got $variance.") } + + val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc) + val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0) + dt.setMaxDepth(1) + .setMaxBins(6) + .setSeed(0) + val transformVarDF = dt.fit(varianceDF).transform(varianceDF) + val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map { + case Row(variance: Double) => variance + } + + // Since max depth is set to 1, the best split point is that which splits the data + // into (0.0, 1.0, 2.0) and (10.0, 12.0, 14.0). The predicted variance for each + // data point in the left node is 0.667 and for each data point in the right node + // is 2.667 + val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667) + calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) => + assert(actual ~== expected absTol 1e-3) + } } test("Feature importance with toy data") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index d2fa8d0d6335d7d217f7060444ca980b468a75e6..c90cb8ca1034c87f7f1148d530dbeb0fc6006019 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -182,6 +182,18 @@ private[ml] object TreeTests extends SparkFunSuite { new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) )) + /** + * Create some toy data for testing correctness of variance. + */ + def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(1.0, Vectors.dense(Array(0.0))), + new LabeledPoint(2.0, Vectors.dense(Array(1.0))), + new LabeledPoint(3.0, Vectors.dense(Array(2.0))), + new LabeledPoint(10.0, Vectors.dense(Array(3.0))), + new LabeledPoint(12.0, Vectors.dense(Array(4.0))), + new LabeledPoint(14.0, Vectors.dense(Array(5.0))) + )) + /** * Mapping from all Params to valid settings which differ from the defaults. * This is useful for tests which need to exercise all Params, such as save/load.