Skip to content
Snippets Groups Projects
Commit c633ed32 authored by Yu ISHIKAWA's avatar Yu ISHIKAWA Committed by Xiangrui Meng
Browse files

[SPARK-10284] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.tuning

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #8694 from yu-iskw/SPARK-10284.
parent 69c9830d
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import itertools import itertools
import numpy as np import numpy as np
from pyspark import since
from pyspark.ml.param import Params, Param from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model from pyspark.ml import Estimator, Model
from pyspark.ml.util import keyword_only from pyspark.ml.util import keyword_only
...@@ -47,11 +48,14 @@ class ParamGridBuilder(object): ...@@ -47,11 +48,14 @@ class ParamGridBuilder(object):
True True
>>> all([m in expected for m in output]) >>> all([m in expected for m in output])
True True
.. versionadded:: 1.4.0
""" """
def __init__(self): def __init__(self):
self._param_grid = {} self._param_grid = {}
@since("1.4.0")
def addGrid(self, param, values): def addGrid(self, param, values):
""" """
Sets the given parameters in this grid to fixed values. Sets the given parameters in this grid to fixed values.
...@@ -60,6 +64,7 @@ class ParamGridBuilder(object): ...@@ -60,6 +64,7 @@ class ParamGridBuilder(object):
return self return self
@since("1.4.0")
def baseOn(self, *args): def baseOn(self, *args):
""" """
Sets the given parameters in this grid to fixed values. Sets the given parameters in this grid to fixed values.
...@@ -73,6 +78,7 @@ class ParamGridBuilder(object): ...@@ -73,6 +78,7 @@ class ParamGridBuilder(object):
return self return self
@since("1.4.0")
def build(self): def build(self):
""" """
Builds and returns all combinations of parameters specified Builds and returns all combinations of parameters specified
...@@ -104,6 +110,8 @@ class CrossValidator(Estimator): ...@@ -104,6 +110,8 @@ class CrossValidator(Estimator):
>>> cvModel = cv.fit(dataset) >>> cvModel = cv.fit(dataset)
>>> evaluator.evaluate(cvModel.transform(dataset)) >>> evaluator.evaluate(cvModel.transform(dataset))
0.8333... 0.8333...
.. versionadded:: 1.4.0
""" """
# a placeholder to make it appear in the generated doc # a placeholder to make it appear in the generated doc
...@@ -142,6 +150,7 @@ class CrossValidator(Estimator): ...@@ -142,6 +150,7 @@ class CrossValidator(Estimator):
self._set(**kwargs) self._set(**kwargs)
@keyword_only @keyword_only
@since("1.4.0")
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
""" """
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
...@@ -150,6 +159,7 @@ class CrossValidator(Estimator): ...@@ -150,6 +159,7 @@ class CrossValidator(Estimator):
kwargs = self.setParams._input_kwargs kwargs = self.setParams._input_kwargs
return self._set(**kwargs) return self._set(**kwargs)
@since("1.4.0")
def setEstimator(self, value): def setEstimator(self, value):
""" """
Sets the value of :py:attr:`estimator`. Sets the value of :py:attr:`estimator`.
...@@ -157,12 +167,14 @@ class CrossValidator(Estimator): ...@@ -157,12 +167,14 @@ class CrossValidator(Estimator):
self._paramMap[self.estimator] = value self._paramMap[self.estimator] = value
return self return self
@since("1.4.0")
def getEstimator(self): def getEstimator(self):
""" """
Gets the value of estimator or its default value. Gets the value of estimator or its default value.
""" """
return self.getOrDefault(self.estimator) return self.getOrDefault(self.estimator)
@since("1.4.0")
def setEstimatorParamMaps(self, value): def setEstimatorParamMaps(self, value):
""" """
Sets the value of :py:attr:`estimatorParamMaps`. Sets the value of :py:attr:`estimatorParamMaps`.
...@@ -170,12 +182,14 @@ class CrossValidator(Estimator): ...@@ -170,12 +182,14 @@ class CrossValidator(Estimator):
self._paramMap[self.estimatorParamMaps] = value self._paramMap[self.estimatorParamMaps] = value
return self return self
@since("1.4.0")
def getEstimatorParamMaps(self): def getEstimatorParamMaps(self):
""" """
Gets the value of estimatorParamMaps or its default value. Gets the value of estimatorParamMaps or its default value.
""" """
return self.getOrDefault(self.estimatorParamMaps) return self.getOrDefault(self.estimatorParamMaps)
@since("1.4.0")
def setEvaluator(self, value): def setEvaluator(self, value):
""" """
Sets the value of :py:attr:`evaluator`. Sets the value of :py:attr:`evaluator`.
...@@ -183,12 +197,14 @@ class CrossValidator(Estimator): ...@@ -183,12 +197,14 @@ class CrossValidator(Estimator):
self._paramMap[self.evaluator] = value self._paramMap[self.evaluator] = value
return self return self
@since("1.4.0")
def getEvaluator(self): def getEvaluator(self):
""" """
Gets the value of evaluator or its default value. Gets the value of evaluator or its default value.
""" """
return self.getOrDefault(self.evaluator) return self.getOrDefault(self.evaluator)
@since("1.4.0")
def setNumFolds(self, value): def setNumFolds(self, value):
""" """
Sets the value of :py:attr:`numFolds`. Sets the value of :py:attr:`numFolds`.
...@@ -196,6 +212,7 @@ class CrossValidator(Estimator): ...@@ -196,6 +212,7 @@ class CrossValidator(Estimator):
self._paramMap[self.numFolds] = value self._paramMap[self.numFolds] = value
return self return self
@since("1.4.0")
def getNumFolds(self): def getNumFolds(self):
""" """
Gets the value of numFolds or its default value. Gets the value of numFolds or its default value.
...@@ -231,7 +248,15 @@ class CrossValidator(Estimator): ...@@ -231,7 +248,15 @@ class CrossValidator(Estimator):
bestModel = est.fit(dataset, epm[bestIndex]) bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel) return CrossValidatorModel(bestModel)
@since("1.4.0")
def copy(self, extra=None): def copy(self, extra=None):
"""
Creates a copy of this instance with a randomly generated uid
and some extra params. This copies creates a deep copy of
the embedded paramMap, and copies the embedded and extra parameters over.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
"""
if extra is None: if extra is None:
extra = dict() extra = dict()
newCV = Params.copy(self, extra) newCV = Params.copy(self, extra)
...@@ -246,6 +271,8 @@ class CrossValidator(Estimator): ...@@ -246,6 +271,8 @@ class CrossValidator(Estimator):
class CrossValidatorModel(Model): class CrossValidatorModel(Model):
""" """
Model from k-fold cross validation. Model from k-fold cross validation.
.. versionadded:: 1.4.0
""" """
def __init__(self, bestModel): def __init__(self, bestModel):
...@@ -256,6 +283,7 @@ class CrossValidatorModel(Model): ...@@ -256,6 +283,7 @@ class CrossValidatorModel(Model):
def _transform(self, dataset): def _transform(self, dataset):
return self.bestModel.transform(dataset) return self.bestModel.transform(dataset)
@since("1.4.0")
def copy(self, extra=None): def copy(self, extra=None):
""" """
Creates a copy of this instance with a randomly generated uid Creates a copy of this instance with a randomly generated uid
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment