diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 4d1d6364d7ad872adf4e0d9892e20bf8da9f051c..07330bb6b0fde20922d81e653c997f8c3251eb66 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. */ -private[ml] trait TrainValidationSplitParams extends ValidatorParams { +private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed { /** * Param for ratio between train and validation data. Must be between 0 and 1. * Default: 0.75 @@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema @@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val numModels = epm.length val metrics = new Array[Double](epm.length) - val Array(training, validation) = - dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) - val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() - val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() + val Array(trainingDataset, validationDataset) = + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + trainingDataset.cache() + validationDataset.cache() // multi-model training logDebug(s"Train split with multiple sets of parameters.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 7cf7b3e087590abf4a92a30caae1af2fdad5f673..4030956fabea3564f16bcf02b5c052b48aaf7679 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -48,6 +48,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -72,6 +73,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) @@ -120,6 +122,7 @@ class TrainValidationSplitSuite .setEvaluator(evaluator) .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false) @@ -140,6 +143,7 @@ class TrainValidationSplitSuite .set(tvs.evaluator, evaluator) .set(tvs.trainRatio, 0.5) .set(tvs.estimatorParamMaps, paramMaps) + .set(tvs.seed, 42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false)