From f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1 Mon Sep 17 00:00:00 2001
From: lewuathe <lewuathe@me.com>
Date: Mon, 26 Jan 2015 18:03:21 -0800
Subject: [PATCH] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on
 trying to train...

... decision tree model

Labels loaded from libsvm files are mapped to 0.0 if they are negative labels because they should be nonnegative value.

Author: lewuathe <lewuathe@me.com>

Closes #3975 from Lewuathe/map-negative-label-to-positive and squashes the following commits:

12d1d59 [lewuathe] [SPARK-5119] Fix code styles
6d9a18a [lewuathe] [SPARK-5119] Organize test codes
62a150c [lewuathe] [SPARK-5119] Modify Impurities throw exceptions with negatie labels
3336c21 [lewuathe] [SPARK-5119] java.lang.ArrayIndexOutOfBoundsException on trying to train decision tree model
---
 .../spark/mllib/tree/impurity/Entropy.scala   |  5 +++
 .../spark/mllib/tree/impurity/Gini.scala      |  5 +++
 .../spark/mllib/tree/ImpuritySuite.scala      | 42 +++++++++++++++++++
 3 files changed, 52 insertions(+)
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 0e02345aa3..b7950e0078 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
       throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
         s" but requires label < numClasses (= $statsSize).")
     }
+    if (label < 0) {
+      throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+        s"but requires label is non-negative.")
+    }
     allStats(offset + label.toInt) += instanceWeight
   }
 
@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc
     val lbl = label.toInt
     require(lbl < stats.length,
       s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+    require(lbl >= 0, "Entropy does not support negative labels")
     val cnt = count
     if (cnt == 0) {
       0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 7c83cd48e1..c946db9c0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int)
       throw new IllegalArgumentException(s"GiniAggregator given label $label" +
         s" but requires label < numClasses (= $statsSize).")
     }
+    if (label < 0) {
+      throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+        s"but requires label is non-negative.")
+    }
     allStats(offset + label.toInt) += instanceWeight
   }
 
@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula
     val lbl = label.toInt
     require(lbl < stats.length,
       s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+    require(lbl >= 0, "GiniImpurity does not support negative labels")
     val cnt = count
     if (cnt == 0) {
       0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
new file mode 100644
index 0000000000..92b498580a
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
+ */
+class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+  test("Gini impurity does not support negative labels") {
+    val gini = new GiniAggregator(2)
+    intercept[IllegalArgumentException] {
+      gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+    }
+  }
+
+  test("Entropy does not support negative labels") {
+    val entropy = new EntropyAggregator(2)
+    intercept[IllegalArgumentException] {
+      entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+    }
+  }
+}
-- 
GitLab