Skip to content
Snippets Groups Projects
Commit 574571c8 authored by Yu ISHIKAWA's avatar Yu ISHIKAWA Committed by Xiangrui Meng
Browse files

[SPARK-11515][ML] QuantileDiscretizer should take random seed

cc jkbradley

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #9535 from yu-iskw/SPARK-11515.
parent efb65e09
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.{IntParam, _}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{DoubleType, StructType}
......@@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* Params for [[QuantileDiscretizer]].
*/
private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol {
private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
* Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
......@@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String)
/**
 * Sets the name of the output column that will hold the bucketed (discretized) values.
 * @group setParam
 */
def setOutputCol(value: String): this.type = set(outputCol, value)
/**
 * Sets the random seed used when sampling the input data to estimate quantiles,
 * making the computed splits reproducible across runs.
 * @group setParam
 */
def setSeed(value: Long): this.type = set(seed, value)
override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
......@@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String)
}
override def fit(dataset: DataFrame): Bucketizer = {
val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets))
val samples = QuantileDiscretizer
.getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
.map { case Row(feature: Double) => feature }
val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
val splits = QuantileDiscretizer.getSplits(candidates)
......@@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
/**
 * Samples the given dataset to collect rows for quantile-statistics estimation.
 *
 * @param dataset input data; the caller is expected to have selected only the column
 *                to be discretized
 * @param numBins number of buckets requested; drives how many samples are needed
 * @param seed random seed so the sampling (and hence the splits) is reproducible
 * @return the sampled rows, collected to the driver
 */
private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
  val totalSamples = dataset.count()
  require(totalSamples > 0,
    "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
  // Need at least max(numBins^2, 10000) samples for a reasonable quantile estimate.
  val requiredSamples = math.max(numBins * numBins, 10000)
  // Use floating-point division and reuse the already-computed count. The previous
  // `requiredSamples / dataset.count()` was Int/Long *integer* division, which truncated
  // the fraction to 0 whenever requiredSamples < totalSamples, sampling almost nothing,
  // and also triggered a second (potentially expensive) count job.
  val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
  dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
}
/**
......
......@@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
.setNumBuckets(numBucket)
.setNumBuckets(numBucket).setSeed(1)
val result = discretizer.fit(df).transform(df)
val transformedFeatures = result.select("result").collect()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment