From 96028e36b4d08427fdd94df55595849c2346ead4 Mon Sep 17 00:00:00 2001 From: WeichenXu <weichen.xu@databricks.com> Date: Thu, 31 Aug 2017 16:22:40 -0700 Subject: [PATCH] [SPARK-17139][ML][FOLLOW-UP] Add convenient method `asBinary` for casting to BinaryLogisticRegressionSummary ## What changes were proposed in this pull request? add an "asBinary" method to LogisticRegressionSummary for convenient casting to BinaryLogisticRegressionSummary. ## How was this patch tested? Testcase updated. Author: WeichenXu <weichen.xu@databricks.com> Closes #19072 from WeichenXu123/mlor_summary_as_binary. --- .../spark/ml/classification/LogisticRegression.scala | 11 +++++++++++ .../ml/classification/LogisticRegressionSuite.scala | 6 ++++++ project/MimaExcludes.scala | 1 + 3 files changed, 18 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 1869d51af7..f491a679b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1473,6 +1473,17 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Returns weighted averaged f1-measure. */ @Since("2.3.0") def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0) + + /** + * Convenient method for casting to binary logistic regression summary. + * This method will throw an Exception if the summary is not a binary summary. 
+ */ + @Since("2.3.0") + def asBinary: BinaryLogisticRegressionSummary = this match { + case b: BinaryLogisticRegressionSummary => b + case _ => + throw new RuntimeException("Cannot cast to a binary summary.") + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 6649fa4025..6bf1253b71 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -256,6 +256,7 @@ class LogisticRegressionSuite val blorModel = lr.fit(smallBinaryDataset) assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary]) assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset) @@ -265,6 +266,11 @@ class LogisticRegressionSuite mlorModel.binarySummary } } + withClue("cannot cast summary to binary summary multiclass model") { + intercept[RuntimeException] { + mlorModel.summary.asBinary + } + } val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset) assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eecda26abb..27e4183550 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,6 +62,7 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=") ) -- GitLab