diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 861b1d4b66f2302cd6c693b00e300fa809056628..3d1d5b6892422b85a319de73482faf8a71dd1e0c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -539,13 +539,15 @@ class LogisticRegressionModel private[spark] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary( - this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) + @Since("2.0.0") + def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + // Handle possible missing or invalid prediction columns + val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() + new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b81c588e44fcc10b61502f2fe89692628c36505b..5ec02135cc53b7463a9b947c046e4fc187bd2b57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -412,15 +412,15 @@ class LinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + @Since("2.0.0") + def evaluate(dataset: DataFrame): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), this, Array(0D)) + $(labelCol), summaryModel, Array(0D)) } /**