From 2f739567080d804a942cfcca0e22f91ab7cbea36 Mon Sep 17 00:00:00 2001
From: Bryan Cutler <cutlerb@gmail.com>
Date: Thu, 29 Sep 2016 16:31:30 -0700
Subject: [PATCH] [SPARK-17697][ML] Fixed bug in summary calculations that
 pattern match against label without casting

## What changes were proposed in this pull request?
When calling LogisticRegression.evaluate or GeneralizedLinearRegression.evaluate on a Dataset whose label column is not of DoubleType, the summary calculations pattern match the label against a Double and throw a MatchError. This fix casts the label column to DoubleType so that the pattern match always succeeds.

## How was this patch tested?
Added unit tests that call evaluate on datasets whose label column is a numeric type other than Double.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697.
---
 .../classification/LogisticRegression.scala   |  2 +-
 .../GeneralizedLinearRegression.scala         | 11 ++++----
 .../LogisticRegressionSuite.scala             | 18 ++++++++++++-
 .../GeneralizedLinearRegressionSuite.scala    | 25 +++++++++++++++++++
 4 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 5ab63d1de9..329961a25d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] (
   // TODO: Allow the user to vary the number of bins using a setBins method in
   // BinaryClassificationMetrics. For now the default is set to 100.
   @transient private val binaryMetrics = new BinaryClassificationMetrics(
-    predictions.select(probabilityCol, labelCol).rdd.map {
+    predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
       case Row(score: Vector, label: Double) => (score(1), label)
     }, 100
   )
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 02b27fb650..bb9e150c49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
     } else {
       link.unlink(0.0)
     }
-    predictions.select(col(model.getLabelCol), w).rdd.map {
+    predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
       case Row(y: Double, weight: Double) =>
         family.deviance(y, wtdmu, weight)
     }.sum()
@@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
   @Since("2.0.0")
   lazy val deviance: Double = {
     val w = weightCol
-    predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+    predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
       case Row(label: Double, pred: Double, weight: Double) =>
         family.deviance(label, pred, weight)
     }.sum()
@@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
   lazy val aic: Double = {
     val w = weightCol
     val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
-    val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
-      case Row(label: Double, pred: Double, weight: Double) =>
-        (label, pred, weight)
+    val t = predictions.select(
+      col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
+        case Row(label: Double, pred: Double, weight: Double) =>
+          (label, pred, weight)
     }
     family.aic(t, deviance, numInstances, weightSum) + 2 * rank
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8451e60144..42b56754e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.LongType
 
 class LogisticRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1776,6 +1777,21 @@ class LogisticRegressionSuite
       summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
   }
 
+  test("evaluate with labels that are not doubles") {
+    // Evaluate a test set with Label that is a numeric type other than Double
+    val lr = new LogisticRegression()
+      .setMaxIter(1)
+      .setRegParam(1.0)
+    val model = lr.fit(smallBinaryDataset)
+    val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+
+    val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
+      col(model.getFeaturesCol))
+    val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+
+    assert(summary.areaUnderROC === longSummary.areaUnderROC)
+  }
+
   test("statistics on training data") {
     // Test that loss is monotonically decreasing.
     val lr = new LogisticRegression()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 937aa7d3c2..ac1ef5feb9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.FloatType
 
 class GeneralizedLinearRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite
       idx += 1
     }
   }
+
+  test("evaluate with labels that are not doubles") {
+    // Evaluate with a dataset whose label column is FloatType to verify it is cast to Double
+    val dataset = Seq(
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
+      Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
+      Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
+    ).toDF()
+
+    val trainer = new GeneralizedLinearRegression()
+      .setMaxIter(1)
+    val model = trainer.fit(dataset)
+    assert(model.hasSummary)
+    val summary = model.summary
+
+    val floatLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType),
+      col(model.getFeaturesCol))
+    val evalSummary = model.evaluate(floatLabelDataset)
+    // Each statistic below pattern matches the label as a Double when it is computed
+    assert(evalSummary.nullDeviance === summary.nullDeviance)
+    assert(evalSummary.deviance === summary.deviance)
+    assert(evalSummary.aic === summary.aic)
+  }
 }
 
 object GeneralizedLinearRegressionSuite {
-- 
GitLab