Commit 574571c8 authored by Yu ISHIKAWA, committed by Xiangrui Meng

[SPARK-11515][ML] QuantileDiscretizer should take random seed

cc jkbradley

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #9535 from yu-iskw/SPARK-11515.
parent efb65e09
mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

@@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param.{IntParam, _}
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom
 /**
  * Params for [[QuantileDiscretizer]].
  */
-private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait QuantileDiscretizerBase extends Params
+  with HasInputCol with HasOutputCol with HasSeed {
 
   /**
    * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
@@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String)
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  def setSeed(value: Long): this.type = set(seed, value)
+
   override def transformSchema(schema: StructType): StructType = {
     validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
@@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String)
   }
 
   override def fit(dataset: DataFrame): Bucketizer = {
-    val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets))
+    val samples = QuantileDiscretizer
+      .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
       .map { case Row(feature: Double) => feature }
     val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
     val splits = QuantileDiscretizer.getSplits(candidates)
@@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
   /**
   * Sampling from the given dataset to collect quantile statistics.
   */
-  private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
+  private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
    val totalSamples = dataset.count()
    require(totalSamples > 0,
      "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
    val requiredSamples = math.max(numBins * numBins, 10000)
    val fraction = math.min(requiredSamples / dataset.count(), 1.0)
-   dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
+   dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
   }
 
  /**
mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala

@@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
     val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
     val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
-      .setNumBuckets(numBucket)
+      .setNumBuckets(numBucket).setSeed(1)
     val result = discretizer.fit(df).transform(df)
     val transformedFeatures = result.select("result").collect()
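For context, a minimal usage sketch of the new seed parameter (not part of this commit; the column name, data, and the `sc`/`sqlContext` bindings are illustrative and assume a spark-shell style environment on this Spark version):

import org.apache.spark.ml.feature.QuantileDiscretizer
import sqlContext.implicits._

// Illustrative single-column DataFrame of Doubles.
val data = Array(18.0, 19.0, 8.0, 5.0, 2.2, 9.7, 0.1, 3.3, 7.5)
val df = sc.parallelize(data.map(Tuple1.apply)).toDF("hour")

// With a fixed seed, getSampledInput draws the same sample on every fit,
// so the Bucketizer returned by fit() has reproducible split points.
val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("bucket")
  .setNumBuckets(3)
  .setSeed(42L) // setter added by this commit

val bucketizer = discretizer.fit(df)
bucketizer.transform(df).show()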
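And a hedged sketch of the property the seed is meant to provide: two fits over the same data with the same seed sample identically in getSampledInput and should therefore produce identical splits. Note the sampling only matters when the computed fraction is below 1.0, i.e. on inputs larger than max(numBuckets^2, 10000) rows; `getSplits` is assumed to be the standard Bucketizer param getter.

// Same data, same seed => same sampled input => identical split points.
val fitWithSeed = (s: Long) => new QuantileDiscretizer()
  .setInputCol("hour").setOutputCol("bucket")
  .setNumBuckets(3).setSeed(s)
  .fit(df)

val splitsA = fitWithSeed(1L).getSplits
val splitsB = fitWithSeed(1L).getSplits
assert(splitsA.sameElements(splitsB))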