diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 19cc323d5073f09dd9e8465bebf3fc888c8f1662..486043e8d97414cb717f51bec12aca670fd340c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") (
     if (handlePersistence) instances.unpersist()
 
     val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
+    val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
     val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
-      model.transform(dataset),
-      $(probabilityCol),
+      summaryModel.transform(dataset),
+      probabilityColName,
       $(labelCol),
       $(featuresCol),
       objectiveHistory)
@@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] (
       new NullPointerException())
   }
 
+  /**
+   * If the probability column is set returns the current model and probability column,
+   * otherwise generates a new column and sets it as the probability column on a new copy
+   * of the current model.
+   */
+  private[classification] def findSummaryModelAndProbabilityCol():
+      (LogisticRegressionModel, String) = {
+    $(probabilityCol) match {
+      case "" =>
+        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+        (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
+      case p => (this, p)
+    }
+  }
+
   private[classification] def setSummary(
       summary: LogisticRegressionTrainingSummary): this.type = {
     this.trainingSummary = Some(summary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a9a6ff8a783d547dba57a329d0504f8eac3133ef..1087afb0cdf79454a8a3a3d5e63b63a5e5a4e10f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -99,6 +99,17 @@ class LogisticRegressionSuite
     assert(model.hasParent)
   }
 
+  test("empty probabilityCol") {
+    val lr = new LogisticRegression().setProbabilityCol("")
+    val model = lr.fit(dataset)
+    assert(model.hasSummary)
+    // Validate that we re-insert a probability column for evaluation
+    val fieldNames = model.summary.predictions.schema.fieldNames
+    assert((dataset.schema.fieldNames.toSet).subsetOf(
+      fieldNames.toSet))
+    assert(fieldNames.exists(s => s.startsWith("probability_")))
+  }
+
   test("setThreshold, getThreshold") {
     val lr = new LogisticRegression
     // default
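
A minimal usage sketch (not part of the patch itself) of the behavior this change enables: fitting with an empty probabilityCol still produces a training summary, because the summary is computed with a copy of the model whose probability column is set to a generated "probability_<uuid>" name. The DataFrame `training` with "label" and "features" columns is an assumption here, as is the surrounding SparkSession setup.

    import org.apache.spark.ml.classification.LogisticRegression

    // Turn off the probability output column on the user-facing model.
    val lr = new LogisticRegression().setProbabilityCol("")
    val model = lr.fit(training)  // `training`: assumed DataFrame with "label"/"features"

    // The training summary is still available; its predictions DataFrame carries the
    // generated probability column used internally for evaluation.
    model.summary.predictions.schema.fieldNames
      .filter(_.startsWith("probability_"))
      .foreach(println)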