Commit 89a41c5b authored by Oliver Pierson, committed by Xiangrui Meng

[SPARK-13600][MLLIB] Use approxQuantile from DataFrame stats in QuantileDiscretizer

## What changes were proposed in this pull request?
QuantileDiscretizer can return an unexpected number of buckets in certain cases. This PR fixes that issue and also refactors QuantileDiscretizer to use approxQuantile from the DataFrame stats functions.
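
In short, the refactored `fit` derives the bucket splits from `DataFrameStatFunctions.approxQuantile` instead of a hand-rolled sampling pass. Below is a minimal standalone sketch of that idea; the helper `quantileSplits` and the column/parameter names are illustrative and not part of the patch, and the Double range for the probabilities mirrors the patch's own expression.

```scala
import org.apache.spark.sql.DataFrame

// Compute numBuckets + 1 split points from evenly spaced quantile probabilities,
// then widen the outer bounds so any value (including unseen data) falls in a bucket.
def quantileSplits(df: DataFrame, inputCol: String, numBuckets: Int, relativeError: Double): Array[Double] = {
  val probabilities = (0.0 to 1.0 by 1.0 / numBuckets).toArray  // 0.0, 1/n, ..., 1.0
  val splits = df.stat.approxQuantile(inputCol, probabilities, relativeError)
  splits(0) = Double.NegativeInfinity
  splits(splits.length - 1) = Double.PositiveInfinity
  splits
}
```

Because the splits come from a single approximate-quantile pass over the full column, the number of buckets no longer depends on which rows happen to be sampled.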
## How was this patch tested?
QuantileDiscretizerSuite unit tests (some existing tests will change or even be removed in this PR)
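
For context, the updated suite replaces the exact-output checks with behavioral checks. A rough sketch of the bucket-count assertion, assuming an existing `sqlCtx: SQLContext`; the column and variable names here are illustrative:

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

// 100,000 evenly spread values should land in exactly numBuckets distinct buckets.
val df = sqlCtx.range(1, 100001).selectExpr("cast(id as double) as input")
val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(5)
val result = discretizer.fit(df).transform(df)
assert(result.select("result").distinct.count == 5)
```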

Author: Oliver Pierson <ocp@gatech.edu>

Closes #11553 from oliverpierson/SPARK-13600.
parent 2dacc81e
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
  with HasInputCol with HasOutputCol with HasSeed {

  /**
-   * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+   * Number of buckets (quantiles, or categories) into which data points are grouped. Must
   * be >= 2.
   * default: 2
   * @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params

  /** @group getParam */
  def getNumBuckets: Int = getOrDefault(numBuckets)

+  /**
+   * Relative error (see documentation for
+   * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+   * Must be a number in [0, 1].
+   * default: 0.001
+   * @group param
+   */
+  val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+    "for approxQuantile",
+    ParamValidators.inRange(0.0, 1.0))
+  setDefault(relativeError -> 0.001)
+
+  /** @group getParam */
+  def getRelativeError: Double = getOrDefault(relativeError)
}

/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
 * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
 * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
 * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
 */
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -65,6 +79,9 @@ final class QuantileDiscretizer(override val uid: String)

  def this() = this(Identifiable.randomUID("quantileDiscretizer"))

+  /** @group setParam */
+  def setRelativeError(value: Double): this.type = set(relativeError, value)
+
  /** @group setParam */
  def setNumBuckets(value: Int): this.type = set(numBuckets, value)
@@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String)
  @Since("2.0.0")
  override def fit(dataset: Dataset[_]): Bucketizer = {
-    val samples = QuantileDiscretizer
-      .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
-      .map { case Row(feature: Double) => feature }
-    val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
-    val splits = QuantileDiscretizer.getSplits(candidates)
+    val splits = dataset.stat.approxQuantile($(inputCol),
+      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+    splits(0) = Double.NegativeInfinity
+    splits(splits.length - 1) = Double.PositiveInfinity
    val bucketizer = new Bucketizer(uid).setSplits(splits)
    copyValues(bucketizer.setParent(this))
  }
@@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {

-  /**
-   * Minimum number of samples required for finding splits, regardless of number of bins. If
-   * the dataset has fewer rows than this value, the entire dataset will be used.
-   */
-  private[spark] val minSamplesRequired: Int = 10000
-
-  /**
-   * Sampling from the given dataset to collect quantile statistics.
-   */
-  private[feature]
-  def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
-    val totalSamples = dataset.count()
-    require(totalSamples > 0,
-      "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
-    val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
-    val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
-    dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
-      .collect()
-  }
-
-  /**
-   * Compute split points with respect to the sample distribution.
-   */
-  private[feature]
-  def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
-    val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
-    }
-    val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
-    val possibleSplits = valueCounts.length - 1
-    if (possibleSplits <= numSplits) {
-      valueCounts.dropRight(1).map(_._1)
-    } else {
-      val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
-      val splitsBuilder = mutable.ArrayBuilder.make[Double]
-      var index = 1
-      // currentCount: sum of counts of values that have been visited
-      var currentCount = valueCounts(0)._2
-      // targetCount: target value for `currentCount`. If `currentCount` is closest value to
-      // `targetCount`, then current value is a split threshold. After finding a split threshold,
-      // `targetCount` is added by stride.
-      var targetCount = stride
-      while (index < valueCounts.length) {
-        val previousCount = currentCount
-        currentCount += valueCounts(index)._2
-        val previousGap = math.abs(previousCount - targetCount)
-        val currentGap = math.abs(currentCount - targetCount)
-        // If adding count of current value to currentCount makes the gap between currentCount and
-        // targetCount smaller, previous value is a split threshold.
-        if (previousGap < currentGap) {
-          splitsBuilder += valueCounts(index - 1)._1
-          targetCount += stride
-        }
-        index += 1
-      }
-      splitsBuilder.result()
-    }
-  }
-
-  /**
-   * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
-   * needed, and adding a default split value of 0 if no good candidates are found.
-   */
-  private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
-    val effectiveValues = if (candidates.nonEmpty) {
-      if (candidates.head == Double.NegativeInfinity
-        && candidates.last == Double.PositiveInfinity) {
-        candidates.drop(1).dropRight(1)
-      } else if (candidates.head == Double.NegativeInfinity) {
-        candidates.drop(1)
-      } else if (candidates.last == Double.PositiveInfinity) {
-        candidates.dropRight(1)
-      } else {
-        candidates
-      }
-    } else {
-      candidates
-    }
-
-    if (effectiveValues.isEmpty) {
-      Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
-    } else {
-      Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
-    }
-  }
-
  @Since("1.6.0")
  override def load(path: String): QuantileDiscretizer = super.load(path)
}
@@ -17,78 +17,60 @@

package org.apache.spark.ml.feature

-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf

class QuantileDiscretizerSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

-  import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
-  test("Test quantile discretizer") {
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      10,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      4,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      3,
-      Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
-      Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      2,
-      Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
-      Array("-Infinity, 2.0", "2.0, Infinity"))
-  }
-
-  test("Test getting splits") {
-    val splitTestPoints = Array(
-      Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.NegativeInfinity, Double.PositiveInfinity)
-        -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
-      Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
-    )
-    for ((ori, res) <- splitTestPoints) {
-      assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
-    }
-  }
-
-  test("Test splits on dataset larger than minSamplesRequired") {
+  test("Test observed number of buckets and their sizes match expected values") {
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

-    val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
+    val datasetSize = 100000
    val numBuckets = 5
-    val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+    val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
    val discretizer = new QuantileDiscretizer()
      .setInputCol("input")
      .setOutputCol("result")
      .setNumBuckets(numBuckets)
-      .setSeed(1)
    val result = discretizer.fit(df).transform(df)

    val observedNumBuckets = result.select("result").distinct.count
    assert(observedNumBuckets === numBuckets,
      "Observed number of buckets does not equal expected number of buckets.")
+
+    val relativeError = discretizer.getRelativeError
+    val isGoodBucket = udf {
+      (size: Int) => math.abs(size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
+    }
+    val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+    assert(numGoodBuckets === numBuckets,
+      "Bucket sizes are not within expected relative error tolerance.")
+  }
+
+  test("Test transform method on unseen data") {
+    val sqlCtx = SQLContext.getOrCreate(sc)
+    import sqlCtx.implicits._
+
+    val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+    val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(5)
+
+    val result = discretizer.fit(trainDF).transform(testDF)
+    val firstBucketSize = result.filter(result("result") === 0.0).count
+    val lastBucketSize = result.filter(result("result") === 4.0).count
+
+    assert(firstBucketSize === 30L,
+      s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+    assert(lastBucketSize === 31L,
+      s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
  }

  test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
      .setNumBuckets(6)
    testDefaultReadWrite(t)
  }
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {

-  def checkDiscretizedData(
-      sc: SparkContext,
-      data: Array[Double],
-      numBucket: Int,
-      expectedResult: Array[Double],
-      expectedAttrs: Array[String]): Unit = {
+  test("Verify resulting model has parent") {
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

-    val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
-    val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
-      .setNumBuckets(numBucket).setSeed(1)
+    val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(5)
    val model = discretizer.fit(df)
    assert(model.hasParent)
-    val result = model.transform(df)
-    val transformedFeatures = result.select("result").collect()
-      .map { case Row(transformedFeature: Double) => transformedFeature }
-    val transformedAttrs = Attribute.fromStructField(result.schema("result"))
-      .asInstanceOf[NominalAttribute].values.get
-    assert(transformedFeatures === expectedResult,
-      "Transformed features do not equal expected features.")
-    assert(transformedAttrs === expectedAttrs,
-      "Transformed attributes do not equal expected attributes.")
  }
}