diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index cee2aa6e85523cf7de17ca791447501c9ac2c919..9208127eb1d797f4787e4f5811604f51de50da51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -52,10 +52,12 @@ private[ml] trait CrossValidatorParams extends Params {
   def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
 
   /**
-   * param for the evaluator for selection
+   * param for the evaluator used to select hyper-parameters that maximize the cross-validated
+   * metric
    * @group param
    */
-  val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+  val evaluator: Param[Evaluator] = new Param(this, "evaluator",
+    "evaluator used to select hyper-parameters that maximize the cross-validated metric")
 
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
@@ -120,6 +122,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
       trainingDataset.unpersist()
       var i = 0
       while (i < numModels) {
+        // TODO: duplicate evaluator to take extra params from input
         val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
         logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
         metrics(i) += metric
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 7b875e4b7125444a9dc4bae63633d2968d63bf49..c1b2077c985cfedb06e9508bf9f323cf7dc97218 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only
 from pyspark.mllib.common import inherit_doc
 
 
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator', 'Model']
 
 
 @inherit_doc
@@ -70,6 +70,15 @@ class Transformer(Params):
         raise NotImplementedError()
 
 
+@inherit_doc
+class Model(Transformer):
+    """
+    Abstract class for models that are fitted by estimators.
+    """
+
+    __metaclass__ = ABCMeta
+
+
 @inherit_doc
 class Pipeline(Estimator):
     """
@@ -154,7 +163,7 @@ class Pipeline(Estimator):
 
 
 @inherit_doc
-class PipelineModel(Transformer):
+class PipelineModel(Model):
     """
     Represents a compiled pipeline with transformers and fitted models.
     """
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 1773ab5bdcdb1ab92f1447075a6ffc895dfba043..f6cf2c3439ba5b816c5451877968baff448f08b7 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -16,8 +16,14 @@
 #
 
 import itertools
+import numpy as np
 
-__all__ = ['ParamGridBuilder']
+from pyspark.ml.param import Params, Param
+from pyspark.ml import Estimator, Model
+from pyspark.ml.util import keyword_only
+from pyspark.sql.functions import rand
+
+__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']
 
 
 class ParamGridBuilder(object):
@@ -79,6 +85,179 @@ class ParamGridBuilder(object):
         return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
+class CrossValidator(Estimator):
+    """
+    K-fold cross validation.
+
+    >>> from pyspark.ml.classification import LogisticRegression
+    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> dataset = sqlContext.createDataFrame(
+    ...     [(Vectors.dense([0.0, 1.0]), 0.0),
+    ...      (Vectors.dense([1.0, 2.0]), 1.0),
+    ...      (Vectors.dense([0.55, 3.0]), 0.0),
+    ...      (Vectors.dense([0.45, 4.0]), 1.0),
+    ...      (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+    ...     ["features", "label"])
+    >>> lr = LogisticRegression()
+    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+    >>> evaluator = BinaryClassificationEvaluator()
+    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+    >>> cvModel = cv.fit(dataset)
+    >>> expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
+    >>> cvModel.transform(dataset).collect() == expected.collect()
+    True
+    """
+
+    # a placeholder to make it appear in the generated doc
+    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+
+    # a placeholder to make it appear in the generated doc
+    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+
+    # a placeholder to make it appear in the generated doc
+    evaluator = Param(
+        Params._dummy(), "evaluator",
+        "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+
+    # a placeholder to make it appear in the generated doc
+    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
+
+    @keyword_only
+    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        """
+        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+        """
+        super(CrossValidator, self).__init__()
+        #: param for estimator to be cross-validated
+        self.estimator = Param(self, "estimator", "estimator to be cross-validated")
+        #: param for estimator param maps
+        self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
+        #: param for the evaluator used to select hyper-parameters that
+        #: maximize the cross-validated metric
+        self.evaluator = Param(
+            self, "evaluator",
+            "evaluator used to select hyper-parameters that maximize the cross-validated metric")
+        #: param for number of folds for cross validation
+        self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
+        self._setDefault(numFolds=3)
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    @keyword_only
+    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        """
+        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+        Sets params for cross validator.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def setEstimator(self, value):
+        """
+        Sets the value of :py:attr:`estimator`.
+        """
+        self.paramMap[self.estimator] = value
+        return self
+
+    def getEstimator(self):
+        """
+        Gets the value of estimator or its default value.
+        """
+        return self.getOrDefault(self.estimator)
+
+    def setEstimatorParamMaps(self, value):
+        """
+        Sets the value of :py:attr:`estimatorParamMaps`.
+        """
+        self.paramMap[self.estimatorParamMaps] = value
+        return self
+
+    def getEstimatorParamMaps(self):
+        """
+        Gets the value of estimatorParamMaps or its default value.
+        """
+        return self.getOrDefault(self.estimatorParamMaps)
+
+    def setEvaluator(self, value):
+        """
+        Sets the value of :py:attr:`evaluator`.
+        """
+        self.paramMap[self.evaluator] = value
+        return self
+
+    def getEvaluator(self):
+        """
+        Gets the value of evaluator or its default value.
+        """
+        return self.getOrDefault(self.evaluator)
+
+    def setNumFolds(self, value):
+        """
+        Sets the value of :py:attr:`numFolds`.
+        """
+        self.paramMap[self.numFolds] = value
+        return self
+
+    def getNumFolds(self):
+        """
+        Gets the value of numFolds or its default value.
+        """
+        return self.getOrDefault(self.numFolds)
+
+    def fit(self, dataset, params={}):
+        paramMap = self.extractParamMap(params)
+        est = paramMap[self.estimator]
+        epm = paramMap[self.estimatorParamMaps]
+        numModels = len(epm)
+        eva = paramMap[self.evaluator]
+        nFolds = paramMap[self.numFolds]
+        h = 1.0 / nFolds
+        randCol = self.uid + "_rand"
+        df = dataset.select("*", rand(0).alias(randCol))
+        metrics = np.zeros(numModels)
+        for i in range(nFolds):
+            validateLB = i * h
+            validateUB = (i + 1) * h
+            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
+            validation = df.filter(condition)
+            train = df.filter(~condition)
+            for j in range(numModels):
+                model = est.fit(train, epm[j])
+                # TODO: duplicate evaluator to take extra params from input
+                metric = eva.evaluate(model.transform(validation, epm[j]))
+                metrics[j] += metric
+        bestIndex = np.argmax(metrics)
+        bestModel = est.fit(dataset, epm[bestIndex])
+        return CrossValidatorModel(bestModel)
+
+
+class CrossValidatorModel(Model):
+    """
+    Model from k-fold cross validation.
+    """
+
+    def __init__(self, bestModel):
+        #: best model from cross validation
+        self.bestModel = bestModel
+
+    def transform(self, dataset, params={}):
+        return self.bestModel.transform(dataset, params)
+
+
 if __name__ == "__main__":
     import doctest
-    doctest.testmod()
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext("local[2]", "ml.tuning tests")
+    sqlContext = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlContext'] = sqlContext
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 73741c4b40dfbbdf2abfd67aba1d567ac8031456..0634254bbd5cf3092c02611f6f9e85945881ed2b 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@ from abc import ABCMeta
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
+from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
 from pyspark.mllib.common import inherit_doc
 
 
@@ -133,7 +133,7 @@ class JavaTransformer(Transformer, JavaWrapper):
 
 
 @inherit_doc
-class JavaModel(JavaTransformer):
+class JavaModel(Model, JavaTransformer):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations.