Commit 7058a539 authored by Joseph K. Bradley, committed by Xiangrui Meng

[SPARK-2796] [mllib] DecisionTree bug fix: ordered categorical features

Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
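
To see the effect of the wrong bound concretely: ordered categorical features need one bin per category value (the arity), while unordered categorical features need one bin per candidate subset split, i.e. 2^(k-1) - 1 for arity k. For arity 2 the unordered formula gives only 1 bin, so the old code never examined the second bin of an ordered feature, which is exactly the case the new arity-2 test exercises. A minimal illustrative sketch of the two formulas (the helper name and parameters below are hypothetical, not identifiers from DecisionTree.scala):

    // Illustrative sketch only: number of bins for a categorical feature of arity k.
    def numCategoricalBins(arity: Int, isUnordered: Boolean): Int =
      if (isUnordered) math.pow(2.0, arity - 1).toInt - 1  // one bin per subset split
      else arity                                           // one bin per category value

    numCategoricalBins(arity = 2, isUnordered = true)   // 1
    numCategoricalBins(arity = 2, isUnordered = false)  // 2 (the correct bound for the ordered case)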

Added new test to DecisionTreeSuite to catch this: "regression stump with categorical variables of arity 2"

Bug fix: Modified upper bound discussed above.

Also: Small improvements to coding style in DecisionTree.

CC mengxr manishamde

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1720 from jkbradley/decisiontree-bugfix2 and squashes the following commits:

225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
parent d88e6956
@@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
  val bin = binForFeatures(mid)
  val lowThreshold = bin.lowSplit.threshold
  val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)){
+ if ((lowThreshold < feature) && (highThreshold >= feature)) {
  return mid
  }
  else if (lowThreshold >= feature) {
@@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
  }
  /**
-  * Sequential search helper method to find bin for categorical feature.
+  * Sequential search helper method to find bin for categorical feature
+  * (for classification and regression).
   */
- def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+ def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
  val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+ val featureValue = labeledPoint.features(featureIndex)
  var binIndex = 0
- while (binIndex < numCategoricalBins) {
+ while (binIndex < featureCategories) {
  val bin = bins(featureIndex)(binIndex)
  val categories = bin.highSplit.categories
- val features = labeledPoint.features
- if (categories.contains(features(featureIndex))) {
+ if (categories.contains(featureValue)) {
  return binIndex
  }
  binIndex += 1
  }
+ if (featureValue < 0 || featureValue >= featureCategories) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in" +
+ s" {0,...,${featureCategories - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
+ }
  -1
  }
  if (isFeatureContinuous) {
  // Perform binary search for finding bin for continuous features.
  val binIndex = binarySearchForBins()
- if (binIndex == -1){
+ if (binIndex == -1) {
  throw new UnknownError("no bin was found for continuous variable.")
  }
  binIndex
@@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging {
  if (isUnorderedFeature) {
  sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
  } else {
- sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+ sequentialBinSearchForOrderedCategoricalFeature()
  }
  }
- if (binIndex == -1){
+ if (binIndex == -1) {
  throw new UnknownError("no bin was found for categorical variable.")
  }
  binIndex
@@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging {
  val arrShift = 1 + numFeatures * nodeIndex
  val arrIndex = arrShift + featureIndex
  // Update the left or right count for one bin.
- val aggShift = numClasses * numBins * numFeatures * nodeIndex
- val aggIndex
- = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
- val labelInt = label.toInt
- agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+ val aggIndex =
+ numClasses * numBins * numFeatures * nodeIndex +
+ numClasses * numBins * featureIndex +
+ numClasses * arr(arrIndex).toInt +
+ label.toInt
+ agg(aggIndex) += 1
  }
  /**
@@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging {
  val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
  var featureIndex = 0
  while (featureIndex < numFeatures) {
- if (isMulticlassClassificationWithCategoricalFeatures){
+ if (isMulticlassClassificationWithCategoricalFeatures) {
  val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
  if (isFeatureContinuous) {
  findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging {
  // Iterate over all features.
  var featureIndex = 0
- while (featureIndex < numFeatures){
+ while (featureIndex < numFeatures) {
  // Check whether the feature is continuous.
  val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
  if (isFeatureContinuous) {
@@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging {
  if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
  bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
  splits(featureIndex)(0), Continuous, Double.MinValue)
- for (index <- 1 until numBins - 1){
+ for (index <- 1 until numBins - 1) {
  val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
  Continuous, Double.MinValue)
  bins(featureIndex)(index) = bin
@@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
  assert(accuracy >= requiredAccuracy)
  }
+ def validateRegressor(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ (prediction - expected.label) * (prediction - expected.label)
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE)
+ }
  test("split and bin calculation") {
  val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
  assert(arr.length === 1000)
@@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
  assert(stats.impurity > 0.2)
  }
+ test("regression stump with categorical variables of arity 2") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 2,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val model = DecisionTree.train(rdd, strategy)
+ validateRegressor(model, arr, 0.0)
+ assert(model.numNodes === 3)
+ assert(model.depth === 1)
+ }
  test("stump with fixed label 0 for Gini") {
  val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
  assert(arr.length === 1000)