diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index e2444ab65b43b6fdb968893c947a7b32cad70224..f979319cc4b58bfc904c7a0d884d124022bccfbb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
-private[ml] trait CrossValidatorParams extends Params {
-
-  /**
-   * param for the estimator to be cross-validated
-   * @group param
-   */
-  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
-
-  /** @group getParam */
-  def getEstimator: Estimator[_] = $(estimator)
-
-  /**
-   * param for estimator param maps
-   * @group param
-   */
-  val estimatorParamMaps: Param[Array[ParamMap]] =
-    new Param(this, "estimatorParamMaps", "param maps for the estimator")
-
-  /** @group getParam */
-  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
-
-  /**
-   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
-   * metric
-   * @group param
-   */
-  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
-    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
-
-  /** @group getParam */
-  def getEvaluator: Evaluator = $(evaluator)
-
+private[ml] trait CrossValidatorParams extends ValidatorParams {
   /**
    * Param for number of folds for cross validation.  Must be >= 2.
    * Default: 3
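
The hunk above moves the `estimator`, `estimatorParamMaps`, and `evaluator` params out of `CrossValidatorParams` and into the shared `ValidatorParams` trait introduced later in this diff; the public `CrossValidator` surface is unchanged. A minimal sketch of that unchanged surface (the estimator and grid are illustrative):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression
val cv = new CrossValidator()
  .setEstimator(lr)                                // inherited via ValidatorParams
  .setEvaluator(new BinaryClassificationEvaluator) // inherited via ValidatorParams
  .setEstimatorParamMaps(new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.1, 0.01))
    .build())
  .setNumFolds(3)                                  // still declared on CrossValidatorParams
```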
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c0edc730b6fd67d355da5ebad4f82121b79211df
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
+ */
+private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+  /**
+   * Param for the ratio between train and validation data. Must be between 0 and 1.
+   * Default: 0.75
+   * @group param
+   */
+  val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
+    "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1))
+
+  /** @group getParam */
+  def getTrainRatio: Double = $(trainRatio)
+
+  setDefault(trainRatio -> 0.75)
+}
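
The `trainRatio` param above is validated by `ParamValidators.inRange(0, 1)`, and param validation runs when a value is set. A minimal sketch (assuming the `TrainValidationSplit` estimator defined just below):

```scala
import org.apache.spark.ml.tuning.TrainValidationSplit

val tvs = new TrainValidationSplit() // trainRatio defaults to 0.75
tvs.setTrainRatio(0.8)               // accepted: within [0, 1]
// tvs.setTrainRatio(1.5)            // would throw IllegalArgumentException at set time,
//                                   // rejected by ParamValidators.inRange(0, 1)
```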
+
+/**
+ * :: Experimental ::
+ * Validation for hyper-parameter tuning.
+ * Randomly splits the input dataset into train and validation sets,
+ * and uses the evaluation metric on the validation set to select the best model.
+ * Similar to [[CrossValidator]], but only splits the dataset once.
+ */
+@Experimental
+class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
+  with TrainValidationSplitParams with Logging {
+
+  def this() = this(Identifiable.randomUID("tvs"))
+
+  /** @group setParam */
+  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+
+  /** @group setParam */
+  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+
+  /** @group setParam */
+  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+
+  /** @group setParam */
+  def setTrainRatio(value: Double): this.type = set(trainRatio, value)
+
+  override def fit(dataset: DataFrame): TrainValidationSplitModel = {
+    val schema = dataset.schema
+    transformSchema(schema, logging = true)
+    val sqlCtx = dataset.sqlContext
+    val est = $(estimator)
+    val eval = $(evaluator)
+    val epm = $(estimatorParamMaps)
+    val numModels = epm.length
+    val metrics = new Array[Double](numModels)
+
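+    // Split the rows once according to trainRatio; randomSplit samples each row
+    // independently and takes no seed here, so split sizes are only approximate.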
+    val Array(training, validation) =
+      dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
+    val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+    val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+
+    // multi-model training
+    logDebug(s"Train split with multiple sets of parameters.")
+    val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+    trainingDataset.unpersist()
+    var i = 0
+    while (i < numModels) {
+      // TODO: duplicate evaluator to take extra params from input
+      val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
+      logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+      metrics(i) = metric
+      i += 1
+    }
+    validationDataset.unpersist()
+
+    logInfo(s"Train validation split metrics: ${metrics.toSeq}")
+    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+    logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+    logInfo(s"Best train validation split metric: $bestMetric.")
+    val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+    copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    $(estimator).transformSchema(schema)
+  }
+
+  override def validateParams(): Unit = {
+    super.validateParams()
+    val est = $(estimator)
+    for (paramMap <- $(estimatorParamMaps)) {
+      est.copy(paramMap).validateParams()
+    }
+  }
+
+  override def copy(extra: ParamMap): TrainValidationSplit = {
+    val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
+    if (copied.isDefined(estimator)) {
+      copied.setEstimator(copied.getEstimator.copy(extra))
+    }
+    if (copied.isDefined(evaluator)) {
+      copied.setEvaluator(copied.getEvaluator.copy(extra))
+    }
+    copied
+  }
+}
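
End-to-end, the estimator above trains one model per param map on the training split, scores each on the validation split, and refits the winner on the full input. A minimal usage sketch mirroring the suite at the end of this diff (`dataset` is an assumed DataFrame with "label" and "features" columns):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LogisticRegression
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.maxIter, Array(10, 50))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator) // areaUnderROC, larger is better
  .setTrainRatio(0.8)

val model = tvs.fit(dataset)                // TrainValidationSplitModel
val predictions = model.transform(dataset)  // delegates to model.bestModel
```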
+
+/**
+ * :: Experimental ::
+ * Model from train validation split.
+ *
+ * @param uid Id.
+ * @param bestModel Model selected by the validation metric and refit on the full dataset.
+ * @param validationMetrics Metric values on the validation set, one per candidate param map.
+ */
+@Experimental
+class TrainValidationSplitModel private[ml] (
+    override val uid: String,
+    val bestModel: Model[_],
+    val validationMetrics: Array[Double])
+  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+
+  override def validateParams(): Unit = {
+    bestModel.validateParams()
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    bestModel.transform(dataset)
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    bestModel.transformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): TrainValidationSplitModel = {
+    val copied = new TrainValidationSplitModel(
+      uid,
+      bestModel.copy(extra).asInstanceOf[Model[_]],
+      validationMetrics.clone())
+    copyValues(copied, extra)
+  }
+}
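
`validationMetrics(i)` is the metric for `estimatorParamMaps(i)`, so zipping the two recovers the full leaderboard. A short sketch, reusing `grid` and `model` from the previous sketch:

```scala
// One (paramMap, metric) row per candidate, in grid order.
grid.zip(model.validationMetrics).foreach { case (paramMap, metric) =>
  println(s"$paramMap => $metric")
}

// fit() picks the max, consistent with "metrics that are maximized".
val (bestMetric, bestIndex) = model.validationMetrics.zipWithIndex.maxBy(_._1)
println(s"best: ${grid(bestIndex)} with metric $bestMetric")
```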
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8897ab0825acd2425368f3896a6c1317464af7e9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.ml.Estimator
+import org.apache.spark.ml.evaluation.Evaluator
+import org.apache.spark.ml.param.{Param, ParamMap, Params}
+
+/**
+ * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
+ */
+private[ml] trait ValidatorParams extends Params {
+
+  /**
+   * param for the estimator to be validated
+   * @group param
+   */
+  val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+
+  /** @group getParam */
+  def getEstimator: Estimator[_] = $(estimator)
+
+  /**
+   * param for estimator param maps
+   * @group param
+   */
+  val estimatorParamMaps: Param[Array[ParamMap]] =
+    new Param(this, "estimatorParamMaps", "param maps for the estimator")
+
+  /** @group getParam */
+  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
+
+  /**
+   * param for the evaluator used to select hyper-parameters that maximize the validated metric
+   * @group param
+   */
+  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+    "evaluator used to select hyper-parameters that maximize the validated metric")
+
+  /** @group getParam */
+  def getEvaluator: Evaluator = $(evaluator)
+}
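
With the shared trait in place, a new model-selection strategy only needs to mix in `ValidatorParams` and declare its own knobs, exactly as `TrainValidationSplitParams` does with `trainRatio`. A hypothetical sketch (`RepeatedSplitParams` and `numRepeats` are invented for illustration, not part of this patch):

```scala
import org.apache.spark.ml.param.{IntParam, ParamValidators}

private[ml] trait RepeatedSplitParams extends ValidatorParams {

  // Hypothetical param: how many random splits to average over.
  val numRepeats: IntParam = new IntParam(this, "numRepeats",
    "number of random splits to average over (>= 1)", ParamValidators.gtEq(1))

  /** @group getParam */
  def getNumRepeats: Int = $(numRepeats)

  setDefault(numRepeats -> 3)
}
```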
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c8e58f216cceb36901025ecf82a92e01800a7066
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasInputCol
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+
+class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+  test("train validation with logistic regression") {
+    val dataset = sqlContext.createDataFrame(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val tvs = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+    val tvsModel = tvs.fit(dataset)
+    val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(tvs.getTrainRatio === 0.5)
+    assert(parent.getRegParam === 0.001)
+    assert(parent.getMaxIter === 10)
+    assert(tvsModel.validationMetrics.length === lrParamMaps.length)
+  }
+
+  test("train validation with linear regression") {
+    val dataset = sqlContext.createDataFrame(
+      sc.parallelize(LinearDataGenerator.generateLinearInput(
+        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+
+    val trainer = new LinearRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(trainer.regParam, Array(1000.0, 0.001))
+      .addGrid(trainer.maxIter, Array(0, 10))
+      .build()
+    val eval = new RegressionEvaluator()
+    val tvs = new TrainValidationSplit()
+      .setEstimator(trainer)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+    val tvsModel = tvs.fit(dataset)
+    val parent = tvsModel.bestModel.parent.asInstanceOf[LinearRegression]
+    assert(parent.getRegParam === 0.001)
+    assert(parent.getMaxIter === 10)
+    assert(tvsModel.validationMetrics.length === lrParamMaps.length)
+
+    eval.setMetricName("r2")
+    val tvsModel2 = tvs.fit(dataset)
+    val parent2 = tvsModel2.bestModel.parent.asInstanceOf[LinearRegression]
+    assert(parent2.getRegParam === 0.001)
+    assert(parent2.getMaxIter === 10)
+    assert(tvsModel2.validationMetrics.length === lrParamMaps.length)
+  }
+
+  test("validateParams should check estimatorParamMaps") {
+    import TrainValidationSplitSuite._
+
+    val est = new MyEstimator("est")
+    val eval = new MyEvaluator
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(est.inputCol, Array("input1", "input2"))
+      .build()
+
+    val tvs = new TrainValidationSplit()
+      .setEstimator(est)
+      .setEstimatorParamMaps(paramMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+    tvs.validateParams() // This should pass.
+
+    val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "")
+    tvs.setEstimatorParamMaps(invalidParamMaps)
+    intercept[IllegalArgumentException] {
+      tvs.validateParams()
+    }
+  }
+}
+
+object TrainValidationSplitSuite {
+
+  abstract class MyModel extends Model[MyModel]
+
+  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
+
+    override def validateParams(): Unit = require($(inputCol).nonEmpty)
+
+    override def fit(dataset: DataFrame): MyModel = {
+      throw new UnsupportedOperationException
+    }
+
+    override def transformSchema(schema: StructType): StructType = {
+      throw new UnsupportedOperationException
+    }
+
+    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
+  }
+
+  class MyEvaluator extends Evaluator {
+
+    override def evaluate(dataset: DataFrame): Double = {
+      throw new UnsupportedOperationException
+    }
+
+    override val uid: String = "eval"
+
+    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
+  }
+}
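
A note on the last test: `validateParams()` copies the estimator with each param map and re-validates the copy, so the appended `ParamMap(est.inputCol -> "")` trips `MyEstimator`'s `require($(inputCol).nonEmpty)`. A minimal sketch of that failure path:

```scala
import org.apache.spark.ml.param.ParamMap
import TrainValidationSplitSuite.MyEstimator

val est = new MyEstimator("est")
val bad = est.copy(ParamMap(est.inputCol -> "")) // empty inputCol is accepted at set time
// bad.validateParams() // would throw IllegalArgumentException from require(...)
```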