diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9c4c59a5e60fa54bc960dd197d9d8e89d0edb6e8..f8bcbeedfb042e5cd653551938e21b237917b31a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -359,8 +359,16 @@ class LogisticRegressionSuite
         assert(pred == predFromProb)
     }
 
-    // force it to use probability2prediction
+    // force it to use raw2prediction
     model.setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
+    resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    model.setRawPredictionCol("")
     val resultsUsingProb2Predict =
       model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
     resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
@@ -405,8 +413,16 @@ class LogisticRegressionSuite
         assert(pred == predFromProb)
     }
 
-    // force it to use probability2prediction
+    // force it to use raw2prediction
     model.setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
+    resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    model.setRawPredictionCol("")
     val resultsUsingProb2Predict =
       model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
     resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {