Skip to content
Snippets Groups Projects
Commit 529d6ce8 authored by Xusen Yin's avatar Xusen Yin Committed by Joseph K. Bradley
Browse files

[SPARK-14181] TrainValidationSplit should have HasSeed

https://issues.apache.org/jira/browse/SPARK-14181

TrainValidationSplit should have HasSeed for the random split of RDD. I also changed the random split from the RDD function to the DataFrame function.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11985 from yinxusen/SPARK-14181.
parent bdabfd43
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging ...@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._ import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.StructType
...@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType ...@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType
/** /**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
*/ */
private[ml] trait TrainValidationSplitParams extends ValidatorParams { private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
/** /**
* Param for ratio between train and validation data. Must be between 0 and 1. * Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75 * Default: 0.75
...@@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St ...@@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("1.5.0") @Since("1.5.0")
def setTrainRatio(value: Double): this.type = set(trainRatio, value) def setTrainRatio(value: Double): this.type = set(trainRatio, value)
/** @group setParam */
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
@Since("1.5.0") @Since("1.5.0")
override def fit(dataset: DataFrame): TrainValidationSplitModel = { override def fit(dataset: DataFrame): TrainValidationSplitModel = {
val schema = dataset.schema val schema = dataset.schema
...@@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St ...@@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val numModels = epm.length val numModels = epm.length
val metrics = new Array[Double](epm.length) val metrics = new Array[Double](epm.length)
val Array(training, validation) = val Array(trainingDataset, validationDataset) =
dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() trainingDataset.cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() validationDataset.cache()
// multi-model training // multi-model training
logDebug(s"Train split with multiple sets of parameters.") logDebug(s"Train split with multiple sets of parameters.")
......
...@@ -48,6 +48,7 @@ class TrainValidationSplitSuite ...@@ -48,6 +48,7 @@ class TrainValidationSplitSuite
.setEstimatorParamMaps(lrParamMaps) .setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval) .setEvaluator(eval)
.setTrainRatio(0.5) .setTrainRatio(0.5)
.setSeed(42L)
val cvModel = cv.fit(dataset) val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(cv.getTrainRatio === 0.5) assert(cv.getTrainRatio === 0.5)
...@@ -72,6 +73,7 @@ class TrainValidationSplitSuite ...@@ -72,6 +73,7 @@ class TrainValidationSplitSuite
.setEstimatorParamMaps(lrParamMaps) .setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval) .setEvaluator(eval)
.setTrainRatio(0.5) .setTrainRatio(0.5)
.setSeed(42L)
val cvModel = cv.fit(dataset) val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001) assert(parent.getRegParam === 0.001)
...@@ -120,6 +122,7 @@ class TrainValidationSplitSuite ...@@ -120,6 +122,7 @@ class TrainValidationSplitSuite
.setEvaluator(evaluator) .setEvaluator(evaluator)
.setTrainRatio(0.5) .setTrainRatio(0.5)
.setEstimatorParamMaps(paramMaps) .setEstimatorParamMaps(paramMaps)
.setSeed(42L)
val tvs2 = testDefaultReadWrite(tvs, testParams = false) val tvs2 = testDefaultReadWrite(tvs, testParams = false)
...@@ -140,6 +143,7 @@ class TrainValidationSplitSuite ...@@ -140,6 +143,7 @@ class TrainValidationSplitSuite
.set(tvs.evaluator, evaluator) .set(tvs.evaluator, evaluator)
.set(tvs.trainRatio, 0.5) .set(tvs.trainRatio, 0.5)
.set(tvs.estimatorParamMaps, paramMaps) .set(tvs.estimatorParamMaps, paramMaps)
.set(tvs.seed, 42L)
val tvs2 = testDefaultReadWrite(tvs, testParams = false) val tvs2 = testDefaultReadWrite(tvs, testParams = false)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment