Commit 909c6d81 authored by MechCoder, committed by Yanbo Liang

[SPARK-16307][ML] Add test to verify the predicted variances of a DT on toy data

## What changes were proposed in this pull request?

The current tests assume that `impurity.calculate()` returns the variance correctly. It would be better to make the tests independent of this assumption; in other words, verify that the computed variance matches a variance computed by hand on a small tree.

## How was this patch tested?

The patch itself is a test.
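For reference, the expected values in the new test can be reproduced by hand. The sketch below is not part of the patch; it computes the biased (population) variance, which is what Spark's variance impurity uses, for the labels in the two leaves produced by the depth-1 split:

// Standalone sketch, not part of the patch. Spark's variance impurity is the
// biased (population) variance: sum((x - mean)^2) / n.
object ManualVariance {
  def populationVariance(labels: Seq[Double]): Double = {
    val mean = labels.sum / labels.length
    labels.map(x => (x - mean) * (x - mean)).sum / labels.length
  }

  def main(args: Array[String]): Unit = {
    // Labels in the left and right leaves after the depth-1 split (see the diff below).
    println(populationVariance(Seq(1.0, 2.0, 3.0)))    // 0.666... (0.667 in the test)
    println(populationVariance(Seq(10.0, 12.0, 14.0))) // 2.666... (2.667 in the test)
  }
}

Running it prints 2/3 and 8/3, matching the test's expected 0.667 and 2.667 up to the 1e-3 tolerance.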

Author: MechCoder <mks542@nyu.edu>

Closes #13981 from MechCoder/dt_variance.
parent 7e28fabd
@@ -22,6 +22,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
   DecisionTreeSuite => OldDecisionTreeSuite}
@@ -96,6 +97,25 @@ class DecisionTreeRegressorSuite
       assert(variance === expectedVariance,
         s"Expected variance $expectedVariance but got $variance.")
     }
+
+    val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+    val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0)
+    dt.setMaxDepth(1)
+      .setMaxBins(6)
+      .setSeed(0)
+    val transformVarDF = dt.fit(varianceDF).transform(varianceDF)
+    val calculatedVariances = transformVarDF.select(dt.getVarianceCol).collect().map {
+      case Row(variance: Double) => variance
+    }
+
+    // Since max depth is set to 1, the best split separates the labels into
+    // (1.0, 2.0, 3.0) and (10.0, 12.0, 14.0). The predicted variance for each
+    // data point in the left node is 0.667 and for each data point in the right
+    // node is 2.667.
+    val expectedVariances = Array(0.667, 0.667, 0.667, 2.667, 2.667, 2.667)
+    calculatedVariances.zip(expectedVariances).foreach { case (actual, expected) =>
+      assert(actual ~== expected absTol 1e-3)
+    }
   }

   test("Feature importance with toy data") {
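The `~==` operator used above comes from `org.apache.spark.ml.util.TestingUtils._`, the import added in the first hunk. As a rough, self-contained illustration of the absolute-tolerance comparison it performs (not Spark's actual implementation, which also supports a relative tolerance via `relTol`):

// Rough illustration of a check similar to `actual ~== expected absTol 1e-3`;
// not Spark's actual TestingUtils implementation.
object TolCheck {
  def approxEquals(actual: Double, expected: Double, absTol: Double): Boolean =
    math.abs(actual - expected) <= absTol

  def main(args: Array[String]): Unit = {
    assert(approxEquals(2.0 / 3.0, 0.667, absTol = 1e-3))  // left-leaf variance
    assert(approxEquals(8.0 / 3.0, 2.667, absTol = 1e-3))  // right-leaf variance
    println("both within tolerance")
  }
}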
@@ -182,6 +182,18 @@ private[ml] object TreeTests extends SparkFunSuite
     new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0))
   ))

+  /**
+   * Create some toy data for testing correctness of variance.
+   */
+  def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq(
+    new LabeledPoint(1.0, Vectors.dense(Array(0.0))),
+    new LabeledPoint(2.0, Vectors.dense(Array(1.0))),
+    new LabeledPoint(3.0, Vectors.dense(Array(2.0))),
+    new LabeledPoint(10.0, Vectors.dense(Array(3.0))),
+    new LabeledPoint(12.0, Vectors.dense(Array(4.0))),
+    new LabeledPoint(14.0, Vectors.dense(Array(5.0)))
+  ))
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
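For completeness, here is an illustrative brute-force check (not Spark's split-search code) that a depth-1 tree on this toy data splits between feature values 2.0 and 3.0, i.e. that this threshold minimizes the weighted population variance of the two children:

// Illustrative brute-force search, not Spark's split-finding code.
object BestSplitCheck {
  // (feature, label) pairs taken from TreeTests.varianceData in the diff above.
  val points: Seq[(Double, Double)] =
    Seq(0.0 -> 1.0, 1.0 -> 2.0, 2.0 -> 3.0, 3.0 -> 10.0, 4.0 -> 12.0, 5.0 -> 14.0)

  def popVar(xs: Seq[Double]): Double = {
    val m = xs.sum / xs.length
    xs.map(x => (x - m) * (x - m)).sum / xs.length
  }

  def main(args: Array[String]): Unit = {
    val labels = points.map(_._2)
    // Minimize the weighted child variance over all candidate split positions.
    val best = (0 until points.length - 1).minBy { i =>
      val (left, right) = labels.splitAt(i + 1)
      left.length * popVar(left) + right.length * popVar(right)
    }
    println(s"best split after feature value ${points(best)._1}")  // prints 2.0
  }
}

It prints `best split after feature value 2.0`, consistent with the comment in the new test.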