diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index c1b2077c985cfedb06e9508bf9f323cf7dc97218..fdbae06405f6a0345c3031b3d5541974ac0ada7f 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -179,7 +179,7 @@ class PipelineModel(Model): return dataset -class Evaluator(object): +class Evaluator(Params): """ Base class for evaluators that compute metrics from predictions. """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 3a42bcf72389496ba985b2fd831c72256bdcb966..75bb5d749ca87c586621812214a67ee93e58a9bf 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -34,7 +34,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame from pyspark.ml.param import Param from pyspark.ml.param.shared import HasMaxIter, HasInputCol -from pyspark.ml.pipeline import Transformer, Estimator, Pipeline +from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer class MockDataset(DataFrame): @@ -77,7 +77,7 @@ class MockEstimator(Estimator): return model -class MockModel(MockTransformer, Transformer): +class MockModel(MockTransformer, Model): def __init__(self): super(MockModel, self).__init__() diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 28e3727f2c064bd00803b4b6ebd2860142e3d756..86f4dc7368be07df883f5dff6e2a539739eeec70 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -236,6 +236,7 @@ class CrossValidatorModel(Model): """ def __init__(self, bestModel): + super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel