diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 8fd0ce2f2e26cf79777e9be4860cfa8b5f5656bc..2a294d388182910978650fd4b1b51d3e8c616a40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param.{IntParam, _} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,8 @@ import org.apache.spark.util.random.XORShiftRandom /** * Params for [[QuantileDiscretizer]]. */ -private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol { +private[feature] trait QuantileDiscretizerBase extends Params + with HasInputCol with HasOutputCol with HasSeed { /** * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must @@ -73,6 +74,9 @@ final class QuantileDiscretizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + override def transformSchema(schema: StructType): StructType = { validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) @@ -85,7 +89,8 @@ final class QuantileDiscretizer(override val uid: String) } override def fit(dataset: DataFrame): Bucketizer = { - val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets)) + val samples = QuantileDiscretizer + .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) .map { case Row(feature: Double) => feature } val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) val splits = QuantileDiscretizer.getSplits(candidates) @@ -101,13 +106,13 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi /** * Sampling from the given dataset to collect quantile statistics. */ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { + private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") val requiredSamples = math.max(numBins * numBins, 10000) val fraction = math.min(requiredSamples / dataset.count(), 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() + dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 722f1abde43595287c32273e3dbaad790ff4452f..4fde42972f01b9f563ddb112e67d2149ac479e8e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -93,7 +93,7 @@ private object QuantileDiscretizerSuite extends SparkFunSuite { val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket) + .setNumBuckets(numBucket).setSeed(1) val result = discretizer.fit(df).transform(df) val transformedFeatures = result.select("result").collect()