diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4d1d6364d7ad872adf4e0d9892e20bf8da9f051c..07330bb6b0fde20922d81e653c997f8c3251eb66 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType
 /**
  * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
  */
-private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
   /**
    * Param for ratio between train and validation data. Must be between 0 and 1.
    * Default: 0.75
@@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   @Since("1.5.0")
   def setTrainRatio(value: Double): this.type = set(trainRatio, value)
 
+  /** @group setParam */
+  @Since("2.0.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
   @Since("1.5.0")
   override def fit(dataset: DataFrame): TrainValidationSplitModel = {
     val schema = dataset.schema
@@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
 
-    val Array(training, validation) =
-      dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
-    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
-    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+    val Array(trainingDataset, validationDataset) =
+      dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
+    trainingDataset.cache()
+    validationDataset.cache()
 
     // multi-model training
     logDebug(s"Train split with multiple sets of parameters.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 7cf7b3e087590abf4a92a30caae1af2fdad5f673..4030956fabea3564f16bcf02b5c052b48aaf7679 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -48,6 +48,7 @@ class TrainValidationSplitSuite
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
+      .setSeed(42L)
     val cvModel = cv.fit(dataset)
     val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
     assert(cv.getTrainRatio === 0.5)
@@ -72,6 +73,7 @@ class TrainValidationSplitSuite
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
       .setTrainRatio(0.5)
+      .setSeed(42L)
     val cvModel = cv.fit(dataset)
     val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
     assert(parent.getRegParam === 0.001)
@@ -120,6 +122,7 @@ class TrainValidationSplitSuite
       .setEvaluator(evaluator)
       .setTrainRatio(0.5)
       .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 
@@ -140,6 +143,7 @@ class TrainValidationSplitSuite
       .set(tvs.evaluator, evaluator)
       .set(tvs.trainRatio, 0.5)
       .set(tvs.estimatorParamMaps, paramMaps)
+      .set(tvs.seed, 42L)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)