Skip to content
Snippets Groups Projects
Commit f2ba5c6f authored by lewuathe's avatar lewuathe Committed by Xiangrui Meng
Browse files

[SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train...

... decision tree model

Labels loaded from libsvm files are mapped to 0.0 if they are negative labels because they should be nonnegative value.

Author: lewuathe <lewuathe@me.com>

Closes #3975 from Lewuathe/map-negative-label-to-positive and squashes the following commits:

12d1d59 [lewuathe] [SPARK-5119] Fix code styles
6d9a18a [lewuathe] [SPARK-5119] Organize test codes
62a150c [lewuathe] [SPARK-5119] Modify Impurities throw exceptions with negatie labels
3336c21 [lewuathe] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train decision tree model
parent 661e0fca
No related branches found
No related tags found
No related merge requests found
...@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int) ...@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
throw new IllegalArgumentException(s"EntropyAggregator given label $label" + throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).") s" but requires label < numClasses (= $statsSize).")
} }
if (label < 0) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s"but requires label is non-negative.")
}
allStats(offset + label.toInt) += instanceWeight allStats(offset + label.toInt) += instanceWeight
} }
...@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc ...@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc
val lbl = label.toInt val lbl = label.toInt
require(lbl < stats.length, require(lbl < stats.length,
s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
require(lbl >= 0, "Entropy does not support negative labels")
val cnt = count val cnt = count
if (cnt == 0) { if (cnt == 0) {
0 0
......
...@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int) ...@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int)
throw new IllegalArgumentException(s"GiniAggregator given label $label" + throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).") s" but requires label < numClasses (= $statsSize).")
} }
if (label < 0) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s"but requires label is non-negative.")
}
allStats(offset + label.toInt) += instanceWeight allStats(offset + label.toInt) += instanceWeight
} }
...@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula ...@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula
val lbl = label.toInt val lbl = label.toInt
require(lbl < stats.length, require(lbl < stats.length,
s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
require(lbl >= 0, "GiniImpurity does not support negative labels")
val cnt = count val cnt = count
if (cnt == 0) { if (cnt == 0) {
0 0
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree
import org.scalatest.FunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
/**
* Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
*/
class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
}
}
test("Entropy does not support negative labels") {
val entropy = new EntropyAggregator(2)
intercept[IllegalArgumentException] {
entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment