Skip to content
Snippets Groups Projects
Commit 3a44aebd authored by Martin Menestret, committed by Joseph K. Bradley
Browse files

[SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed

Extend CrossValidator with HasSeed in PySpark.

This PR replaces [https://github.com/apache/spark/pull/7997]

CC: yanboliang thunterdb mmenestret  Would one of you mind taking a look?  Thanks!

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Martin MENESTRET <mmenestret@ippon.fr>

Closes #10268 from jkbradley/pyspark-cv-seed.
parent 9657ee87
No related branches found
No related tags found
No related merge requests found
......@@ -19,8 +19,9 @@ import itertools
import numpy as np
from pyspark import since
from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand
......@@ -89,7 +90,7 @@ class ParamGridBuilder(object):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
class CrossValidator(Estimator):
class CrossValidator(Estimator, HasSeed):
"""
K-fold cross validation.
......@@ -129,9 +130,11 @@ class CrossValidator(Estimator):
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
seed=None):
"""
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
__init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
seed=None)
"""
super(CrossValidator, self).__init__()
#: param for estimator to be cross-validated
......@@ -151,9 +154,11 @@ class CrossValidator(Estimator):
@keyword_only
@since("1.4.0")
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
seed=None):
"""
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
seed=None):
Sets params for cross validator.
"""
kwargs = self.setParams._input_kwargs
......@@ -225,9 +230,10 @@ class CrossValidator(Estimator):
numModels = len(epm)
eva = self.getOrDefault(self.evaluator)
nFolds = self.getOrDefault(self.numFolds)
seed = self.getOrDefault(self.seed)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
df = dataset.select("*", rand(0).alias(randCol))
df = dataset.select("*", rand(seed).alias(randCol))
metrics = np.zeros(numModels)
for i in range(nFolds):
validateLB = i * h
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment