Skip to content
Snippets Groups Projects
Commit bd97840d authored by Xiangrui Meng
Browse files

[SPARK-7432] [MLLIB] fix flaky CrossValidator doctest

The new test uses CV to compare `maxIter=0` and `maxIter=1`, and validates the evaluation result. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6572 from mengxr/SPARK-7432 and squashes the following commits:

c236bb8 [Xiangrui Meng] fix flaky cv doctest
parent 445647a1
No related branches found
No related tags found
No related merge requests found
......@@ -91,20 +91,19 @@ class CrossValidator(Estimator):
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.mllib.linalg import Vectors
>>> dataset = sqlContext.createDataFrame(
... [(Vectors.dense([0.0, 1.0]), 0.0),
... (Vectors.dense([1.0, 2.0]), 1.0),
... (Vectors.dense([0.55, 3.0]), 0.0),
... (Vectors.dense([0.45, 4.0]), 1.0),
... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
... [(Vectors.dense([0.0]), 0.0),
... (Vectors.dense([0.4]), 1.0),
... (Vectors.dense([0.5]), 0.0),
... (Vectors.dense([0.6]), 1.0),
... (Vectors.dense([1.0]), 1.0)] * 10,
... ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
>>> # SPARK-7432: The following test is flaky.
>>> # cvModel = cv.fit(dataset)
>>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
>>> # cvModel.transform(dataset).collect() == expected.collect()
>>> cvModel = cv.fit(dataset)
>>> evaluator.evaluate(cvModel.transform(dataset))
0.8333...
"""
# a placeholder to make it appear in the generated doc
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment