diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ffeb4459e1aac59f5e89803a996e4a79185fd1e7..b64858214d20db3cc7b95c14f803681601c1ef09 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,14 +18,11 @@ import itertools
 
 import numpy as np
 
-from pyspark import SparkContext
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.ml.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -232,8 +229,9 @@ class CrossValidator(Estimator, ValidatorParams):
             condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
             validation = df.filter(condition)
             train = df.filter(~condition)
+            models = est.fit(train, epm)
             for j in range(numModels):
-                model = est.fit(train, epm[j])
+                model = models[j]
                 # TODO: duplicate evaluator to take extra params from input
                 metric = eva.evaluate(model.transform(validation, epm[j]))
                 metrics[j] += metric/nFolds
@@ -388,8 +386,9 @@ class TrainValidationSplit(Estimator, ValidatorParams):
         condition = (df[randCol] >= tRatio)
         validation = df.filter(condition)
         train = df.filter(~condition)
+        models = est.fit(train, epm)
         for j in range(numModels):
-            model = est.fit(train, epm[j])
+            model = models[j]
             metric = eva.evaluate(model.transform(validation, epm[j]))
             metrics[j] += metric
         if eva.isLargerBetter():
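
Note for reviewers: both rewritten loops rely on the list form of Estimator.fit(), which accepts a sequence of param maps and returns one fitted model per map. A minimal sketch of that behavior follows; the LogisticRegression estimator, toy DataFrame, and session setup are illustrative assumptions, not part of this patch.

# Minimal sketch: Estimator.fit() called with a list of param maps returns
# a list of models, one per map -- the behavior `models = est.fit(train, epm)`
# in the patch depends on. Toy data and estimator choice are assumptions.
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import ParamGridBuilder

spark = SparkSession.builder.master("local[2]").appName("fit-list-sketch").getOrCreate()

train = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)],
    ["features", "label"])

lr = LogisticRegression(maxIter=5)
epm = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()

# One call fits every param map; indexing models[j] pairs each fitted
# model with epm[j], exactly as the patched loops do.
models = lr.fit(train, epm)
for j, model in enumerate(models):
    print(j, epm[j], model.coefficients)

spark.stop()

Hoisting the fit out of the inner loop this way hands the whole param grid to the estimator in one call, which gives it the chance to share work across param maps (and, in later Spark versions, to parallelize the fits) instead of forcing numModels independent fit() calls.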