Commit 23b9863e authored by Xiangrui Meng, committed by Joseph K. Bradley

[SPARK-7559] [MLLIB] Bucketizer should include the right most boundary in the last bucket.

We currently give +inf special treatment in `Bucketizer`. This can be simplified by always including the largest split value in the last bucket: e.g., splits (x1, x2, x3) define buckets [x1, x2) and [x2, x3]. This shouldn't affect user code much, and some applications need to include the right-most value. For example, we can bucketize ratings from 0 to 10 into bad, neutral, and good with splits 0, 4, 6, 10. It would read oddly if users had to specify 0, 4, 6, 10.1 (or 11) instead.

This also updates the implementation to use `Arrays.binarySearch`, and the tests to use `withClue`.
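
An illustrative sketch of the new rule (not part of the commit; `bucket` is a hypothetical helper):

    // With splits (0, 4, 6, 10), ratings map to buckets
    // [0, 4), [4, 6), and [6, 10]; the last bucket is closed on the right.
    val splits = Array(0.0, 4.0, 6.0, 10.0)
    def bucket(r: Double): Int =
      if (r == splits.last) splits.length - 2   // 10.0 -> last bucket
      else splits.lastIndexWhere(_ <= r)        // [x, y) elsewhere
    assert(bucket(10.0) == 2)  // right-most value included; no 10.1 needed
    assert(bucket(5.0) == 1)   // neutral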

yinxusen jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6075 from mengxr/SPARK-7559 and squashes the following commits:

e28f910 [Xiangrui Meng] update bucketizer impl
parent 2a41c0d7
@@ -17,6 +17,9 @@
 package org.apache.spark.ml.feature
 
+import java.{util => ju}
+
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
@@ -38,18 +41,19 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
-   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+   * also includes y. Splits should be strictly increasing.
+   * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
-    "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
-      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
-      " all Double values; otherwise, values outside the splits specified will be treated as" +
-      " errors.",
+    "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
+      "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
+      "bucket, which also includes y. The splits should be strictly increasing. " +
+      "Values at -inf, inf must be explicitly provided to cover all Double values; " +
+      "otherwise, values outside the splits specified will be treated as errors.",
     Bucketizer.checkSplits)
 
   /** @group getParam */
@@ -104,28 +108,25 @@ private[feature] object Bucketizer {
   /**
    * Binary searching in several buckets to place each data point.
-   * @throws RuntimeException if a feature is < splits.head or >= splits.last
+   * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  def binarySearchForBuckets(
-      splits: Array[Double],
-      feature: Double): Double = {
-    // Check bounds. We make an exception for +inf so that it can exist in some bin.
-    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
-      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
-        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
-        s"the lower/upper bound constraints.")
-    }
-    var left = 0
-    var right = splits.length - 2
-    while (left < right) {
-      val mid = (left + right) / 2
-      val split = splits(mid + 1)
-      if (feature < split) {
-        right = mid
+  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    if (feature == splits.last) {
+      splits.length - 2
+    } else {
+      val idx = ju.Arrays.binarySearch(splits, feature)
+      if (idx >= 0) {
+        idx
       } else {
-        left = mid + 1
+        val insertPos = -idx - 1
+        if (insertPos == 0 || insertPos == splits.length) {
+          throw new SparkException(s"Feature value $feature out of Bucketizer bounds" +
+            s" [${splits.head}, ${splits.last}]. Check your features, or loosen " +
+            s"the lower/upper bound constraints.")
+        } else {
+          insertPos - 1
+        }
       }
     }
-    left
   }
 }
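
The key detail above is how `java.util.Arrays.binarySearch` reports a miss: it returns `-(insertionPoint) - 1`, so `-idx - 1` recovers where the feature would be inserted. A standalone sketch:

    import java.{util => ju}

    val splits = Array(-0.5, 0.0, 0.5)              // strictly increasing
    ju.Arrays.binarySearch(splits, 0.0)             // 1: exact hit at index 1
    val idx = ju.Arrays.binarySearch(splits, 0.2)   // miss: -(2) - 1 = -3
    val insertPos = -idx - 1                        // 2, so bucket = 2 - 1 = 1
    // insertPos == 0 or insertPos == splits.length would mean out of bounds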
@@ -57,16 +57,18 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
     // Check for exceptions when using a set of invalid feature values.
     val invalidData1: Array[Double] = Array(-0.9) ++ validData
-    val invalidData2 = Array(0.5) ++ validData
+    val invalidData2 = Array(0.51) ++ validData
     val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF1).collect()
-      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF1).collect()
+      }
     }
     val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
-    intercept[RuntimeException]{
-      bucketizer.transform(badDF2).collect()
-      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+      intercept[SparkException] {
+        bucketizer.transform(badDF2).collect()
+      }
     }
   }
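
The old pattern put a `println` after `transform`, which only ran when no exception was thrown and merely printed to stdout instead of attaching the message to the test failure. ScalaTest's `withClue` prepends its message to any failure inside the block. A minimal self-contained illustration (not from the commit):

    import org.scalatest.FunSuite

    // withClue prepends its message to the failure report if the enclosed
    // assertion fails, e.g. if intercept sees no exception at all.
    class ClueExample extends FunSuite {
      test("intercept with clue") {
        withClue("expected an ArithmeticException: ") {
          intercept[ArithmeticException] { 1 / 0 }
        }
      }
    }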
@@ -137,12 +139,11 @@ private object BucketizerSuite extends FunSuite {
     }
     var i = 0
     while (i < splits.length - 1) {
-      testFeature(splits(i), i)  // Split i should fall in bucket i.
-      testFeature((splits(i) + splits(i + 1)) / 2, i)  // Value between splits i,i+1 should be in i.
+      // Split i should fall in bucket i.
+      testFeature(splits(i), i)
+      // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf.
+      testFeature((splits(i) + splits(i + 1)) / 2, i)
       i += 1
     }
-    if (splits.last === Double.PositiveInfinity) {
-      testFeature(Double.PositiveInfinity, splits.length - 2)
-    }
   }
 }
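
Why dropping the +inf special case is safe: the midpoint of (splits(i), +inf) is +inf itself, which the new `== splits.last` branch places in the last bucket, so the loop already exercises that case. A quick check (a sketch assuming code inside the org.apache.spark.ml.feature package, since `binarySearchForBuckets` is private[feature]):

    val s = Array(0.0, 1.0, Double.PositiveInfinity)
    val mid = (s(1) + s(2)) / 2                       // +inf
    assert(mid == s.last)                             // hits the == splits.last branch
    assert(Bucketizer.binarySearchForBuckets(s, mid) == 1.0)  // last bucket index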