From 518ab5101073ee35d62e33c8f7281a1e6342101e Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@pigscanfly.ca>
Date: Fri, 11 Dec 2015 02:35:53 -0500
Subject: [PATCH] [SPARK-10991][ML] logistic regression training summary handle
 empty prediction col

LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related to https://issues.apache.org/jira/browse/SPARK-9718)

Author: Holden Karau <holden@pigscanfly.ca>
Author: Holden Karau <holden@us.ibm.com>

Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col.
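
As a rough usage sketch (not part of this patch), the behaviour being fixed can be exercised as below; the DataFrame name `training` and its "label"/"features" columns are assumed for illustration:

    import org.apache.spark.ml.classification.LogisticRegression

    // Disable the probability output column entirely.
    val lr = new LogisticRegression().setProbabilityCol("")
    val model = lr.fit(training)

    // With this change, a temporary "probability_<uuid>" column is generated
    // internally so the training summary can still be computed even though
    // the user turned the probability column off.
    println(model.summary.objectiveHistory.mkString(", "))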
---
 .../classification/LogisticRegression.scala   | 20 +++++++++++++++++--
 .../LogisticRegressionSuite.scala             | 11 ++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 19cc323d50..486043e8d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") (
     if (handlePersistence) instances.unpersist()
 
     val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
+    val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
     val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
-      model.transform(dataset),
-      $(probabilityCol),
+      summaryModel.transform(dataset),
+      probabilityColName,
       $(labelCol),
       $(featuresCol),
       objectiveHistory)
@@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] (
         new NullPointerException())
   }
 
+  /**
+   * If the probability column is set, returns the current model and probability column;
+   * otherwise generates a new column name and sets it as the probability column on a new
+   * copy of the current model.
+   */
+  private[classification] def findSummaryModelAndProbabilityCol():
+      (LogisticRegressionModel, String) = {
+    $(probabilityCol) match {
+      case "" =>
+        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+        (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
+      case p => (this, p)
+    }
+  }
+
   private[classification] def setSummary(
       summary: LogisticRegressionTrainingSummary): this.type = {
     this.trainingSummary = Some(summary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index a9a6ff8a78..1087afb0cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -99,6 +99,17 @@ class LogisticRegressionSuite
     assert(model.hasParent)
   }
 
+  test("empty probabilityCol") {
+    val lr = new LogisticRegression().setProbabilityCol("")
+    val model = lr.fit(dataset)
+    assert(model.hasSummary)
+    // Validate that we re-insert a probability column for evaluation
+    val fieldNames = model.summary.predictions.schema.fieldNames
+    assert(dataset.schema.fieldNames.toSet.subsetOf(
+      fieldNames.toSet))
+    assert(fieldNames.exists(s => s.startsWith("probability_")))
+  }
+
   test("setThreshold, getThreshold") {
     val lr = new LogisticRegression
     // default
-- 
GitLab