diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index efe8b93d8235a579bb88b5bbf07a2848fcba14e5..5c7993af645af9b157d839e3eb5c43da8070d75d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
   with HasInputCol with HasOutputCol with HasSeed {
 
   /**
-   * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+   * Number of buckets (quantiles, or categories) into which data points are grouped. Must
    * be >= 2.
    * default: 2
    * @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
 
   /** @group getParam */
   def getNumBuckets: Int = getOrDefault(numBuckets)
+
+  /**
+   * Relative error (see documentation for
+   * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+   * Must be a number in [0, 1].
+   * default: 0.001
+   * @group param
+   */
+  val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+    "for approxQuantile",
+    ParamValidators.inRange(0.0, 1.0))
+  setDefault(relativeError -> 0.001)
+
+  /** @group getParam */
+  def getRelativeError: Double = getOrDefault(relativeError)
 }
 
 /**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
  * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
  * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
  * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
  */
 @Experimental
 final class QuantileDiscretizer(override val uid: String)
@@ -65,6 +79,9 @@ final class QuantileDiscretizer(override val uid: String)
 
   def this() = this(Identifiable.randomUID("quantileDiscretizer"))
 
+  /** @group setParam */
+  def setRelativeError(value: Double): this.type = set(relativeError, value)
+
   /** @group setParam */
   def setNumBuckets(value: Int): this.type = set(numBuckets, value)
 
@@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String)
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
-    val samples = QuantileDiscretizer
-      .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
-      .map { case Row(feature: Double) => feature }
-    val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
-    val splits = QuantileDiscretizer.getSplits(candidates)
+    val splits = dataset.stat.approxQuantile($(inputCol),
+      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+    splits(0) = Double.NegativeInfinity
+    splits(splits.length - 1) = Double.PositiveInfinity
+
     val bucketizer = new Bucketizer(uid).setSplits(splits)
     copyValues(bucketizer.setParent(this))
   }
@@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
 
 @Since("1.6.0")
 object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
 
-  /**
-   * Minimum number of samples required for finding splits, regardless of number of bins. If
-   * the dataset has fewer rows than this value, the entire dataset will be used.
-   */
-  private[spark] val minSamplesRequired: Int = 10000
-
-  /**
-   * Sampling from the given dataset to collect quantile statistics.
-   */
-  private[feature]
-  def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
-    val totalSamples = dataset.count()
-    require(totalSamples > 0,
-      "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
-    val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
-    val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
-    dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
-      .collect()
-  }
-
-  /**
-   * Compute split points with respect to the sample distribution.
-   */
-  private[feature]
-  def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
-    val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
-    }
-    val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
-    val possibleSplits = valueCounts.length - 1
-    if (possibleSplits <= numSplits) {
-      valueCounts.dropRight(1).map(_._1)
-    } else {
-      val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
-      val splitsBuilder = mutable.ArrayBuilder.make[Double]
-      var index = 1
-      // currentCount: sum of counts of values that have been visited
-      var currentCount = valueCounts(0)._2
-      // targetCount: target value for `currentCount`. If `currentCount` is closest value to
-      // `targetCount`, then current value is a split threshold. After finding a split threshold,
-      // `targetCount` is added by stride.
-      var targetCount = stride
-      while (index < valueCounts.length) {
-        val previousCount = currentCount
-        currentCount += valueCounts(index)._2
-        val previousGap = math.abs(previousCount - targetCount)
-        val currentGap = math.abs(currentCount - targetCount)
-        // If adding count of current value to currentCount makes the gap between currentCount and
-        // targetCount smaller, previous value is a split threshold.
-        if (previousGap < currentGap) {
-          splitsBuilder += valueCounts(index - 1)._1
-          targetCount += stride
-        }
-        index += 1
-      }
-      splitsBuilder.result()
-    }
-  }
-
-  /**
-   * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
-   * needed, and adding a default split value of 0 if no good candidates are found.
-   */
-  private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
-    val effectiveValues = if (candidates.nonEmpty) {
-      if (candidates.head == Double.NegativeInfinity
-        && candidates.last == Double.PositiveInfinity) {
-        candidates.drop(1).dropRight(1)
-      } else if (candidates.head == Double.NegativeInfinity) {
-        candidates.drop(1)
-      } else if (candidates.last == Double.PositiveInfinity) {
-        candidates.dropRight(1)
-      } else {
-        candidates
-      }
-    } else {
-      candidates
-    }
-
-    if (effectiveValues.isEmpty) {
-      Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
-    } else {
-      Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
-    }
-  }
-
   @Since("1.6.0")
   override def load(path: String): QuantileDiscretizer = super.load(path)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 25fabf64d5594ee57c969a83f371757e9743fc08..8895d630a08791881cd7f1dd70f58b82af44b398 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,78 +17,60 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf
 
 class QuantileDiscretizerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
-  test("Test quantile discretizer") {
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      10,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      4,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      3,
-      Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
-      Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+  test("Test observed number of buckets and their sizes match expected values") {
+    val sqlCtx = SQLContext.getOrCreate(sc)
+    import sqlCtx.implicits._
 
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      2,
-      Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
-      Array("-Infinity, 2.0", "2.0, Infinity"))
+    val datasetSize = 100000
+    val numBuckets = 5
+    val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(numBuckets)
+    val result = discretizer.fit(df).transform(df)
 
-  }
+    val observedNumBuckets = result.select("result").distinct.count
+    assert(observedNumBuckets === numBuckets,
+      "Observed number of buckets does not equal expected number of buckets.")
 
-  test("Test getting splits") {
-    val splitTestPoints = Array(
-      Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.NegativeInfinity, Double.PositiveInfinity)
-        -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
-      Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
-    )
-    for ((ori, res) <- splitTestPoints) {
-      assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+    val relativeError = discretizer.getRelativeError
+    val isGoodBucket = udf {
+      (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
     }
+    val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+    assert(numGoodBuckets === numBuckets,
+      "Bucket sizes are not within expected relative error tolerance.")
   }
 
-  test("Test splits on dataset larger than minSamplesRequired") {
+  test("Test transform method on unseen data") {
     val sqlCtx = SQLContext.getOrCreate(sc)
     import sqlCtx.implicits._
 
-    val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
-    val numBuckets = 5
-    val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+    val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+    val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
     val discretizer = new QuantileDiscretizer()
      .setInputCol("input")
      .setOutputCol("result")
-      .setNumBuckets(numBuckets)
-      .setSeed(1)
+      .setNumBuckets(5)
 
-    val result = discretizer.fit(df).transform(df)
-    val observedNumBuckets = result.select("result").distinct.count
+    val result = discretizer.fit(trainDF).transform(testDF)
+    val firstBucketSize = result.filter(result("result") === 0.0).count
+    val lastBucketSize = result.filter(result("result") === 4.0).count
 
-    assert(observedNumBuckets === numBuckets,
-      "Observed number of buckets does not equal expected number of buckets.")
+    assert(firstBucketSize === 30L,
+      s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+    assert(lastBucketSize === 31L,
+      s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
   }
 
   test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
       .setNumBuckets(6)
     testDefaultReadWrite(t)
   }
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {
-  def checkDiscretizedData(
-      sc: SparkContext,
-      data: Array[Double],
-      numBucket: Int,
-      expectedResult: Array[Double],
-      expectedAttrs: Array[String]): Unit = {
+
+  test("Verify resulting model has parent") {
     val sqlCtx = SQLContext.getOrCreate(sc)
     import sqlCtx.implicits._
 
-    val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
-    val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
-      .setNumBuckets(numBucket).setSeed(1)
+    val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(5)
     val model = discretizer.fit(df)
     assert(model.hasParent)
-    val result = model.transform(df)
-
-    val transformedFeatures = result.select("result").collect()
-      .map { case Row(transformedFeature: Double) => transformedFeature }
-    val transformedAttrs = Attribute.fromStructField(result.schema("result"))
-      .asInstanceOf[NominalAttribute].values.get
-
-    assert(transformedFeatures === expectedResult,
-      "Transformed features do not equal expected features.")
-
-    assert(transformedAttrs === expectedAttrs,
-      "Transformed attributes do not equal expected attributes.")
   }
 }
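
For reviewers who want to try the new parameter, here is a minimal usage sketch. It is not part of the patch: it assumes a Spark 2.0-era setup with `sc` and an implicit-enabled `SQLContext` in scope (as in the test suite above), and the bucket count and error values are arbitrary.

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SQLContext

val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._

// Toy input: 1000 evenly spaced doubles, mirroring the test fixtures above.
val df = sc.parallelize(1.0 to 1000.0 by 1.0).map(Tuple1.apply).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(4)
  .setRelativeError(0.01) // coarser than the 0.001 default; trades split accuracy for speed

// With this patch, fit() computes the splits via DataFrameStatFunctions.approxQuantile
// and widens the outermost splits to -/+Infinity before building the Bucketizer.
val bucketizer = discretizer.fit(df)
bucketizer.transform(df).show()
```

Note that setting relativeError to 0.0 asks approxQuantile for exact quantiles, which is more expensive on large datasets; the 0.001 default keeps bucket sizes within roughly 0.1% of the dataset size, which is what the first test above verifies.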