From 96028e36b4d08427fdd94df55595849c2346ead4 Mon Sep 17 00:00:00 2001
From: WeichenXu <weichen.xu@databricks.com>
Date: Thu, 31 Aug 2017 16:22:40 -0700
Subject: [PATCH] [SPARK-17139][ML][FOLLOW-UP] Add convenient method `asBinary`
 for casting to BinaryLogisticRegressionSummary

## What changes were proposed in this pull request?

add an "asBinary" method to LogisticRegressionSummary for convenient casting to BinaryLogisticRegressionSummary.

## How was this patch tested?

Testcase updated.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19072 from WeichenXu123/mlor_summary_as_binary.
---
 .../spark/ml/classification/LogisticRegression.scala  | 11 +++++++++++
 .../ml/classification/LogisticRegressionSuite.scala   |  6 ++++++
 project/MimaExcludes.scala                            |  1 +
 3 files changed, 18 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 1869d51af7..f491a679b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1473,6 +1473,17 @@ sealed trait LogisticRegressionSummary extends Serializable {
   /** Returns weighted averaged f1-measure. */
   @Since("2.3.0")
   def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
+
+  /**
+   * Convenient method for casting to binary logistic regression summary.
+   * This method will throws an Exception if the summary is not a binary summary.
+   */
+  @Since("2.3.0")
+  def asBinary: BinaryLogisticRegressionSummary = this match {
+    case b: BinaryLogisticRegressionSummary => b
+    case _ =>
+      throw new RuntimeException("Cannot cast to a binary summary.")
+  }
 }
 
 /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 6649fa4025..6bf1253b71 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -256,6 +256,7 @@ class LogisticRegressionSuite
 
     val blorModel = lr.fit(smallBinaryDataset)
     assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+    assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary])
     assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
 
     val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
@@ -265,6 +266,11 @@ class LogisticRegressionSuite
         mlorModel.binarySummary
       }
     }
+    withClue("cannot cast summary to binary summary multiclass model") {
+      intercept[RuntimeException] {
+        mlorModel.summary.asBinary
+      }
+    }
 
     val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
     assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index eecda26abb..27e4183550 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -62,6 +62,7 @@ object MimaExcludes {
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=")
   )
-- 
GitLab