From 9cff67f3465bc6ffe1b5abee9501e3c17f8fd194 Mon Sep 17 00:00:00 2001 From: Yanbo Liang <ybliang8@gmail.com> Date: Wed, 28 Dec 2016 01:24:18 -0800 Subject: [PATCH] [MINOR][ML] Correct test cases of LoR raw2prediction & probability2prediction. ## What changes were proposed in this pull request? Correct test cases of ```LogisticRegression``` raw2prediction & probability2prediction. ## How was this patch tested? Changed unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #16407 from yanboliang/raw-probability. --- .../LogisticRegressionSuite.scala | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9c4c59a5e6..f8bcbeedfb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -359,8 +359,16 @@ class LogisticRegressionSuite assert(pred == predFromProb) } - // force it to use probability2prediction + // force it to use raw2prediction model.setProbabilityCol("") + val resultsUsingRaw2Predict = + model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() + resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + model.setRawPredictionCol("") val resultsUsingProb2Predict = model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { @@ -405,8 +413,16 @@ class LogisticRegressionSuite assert(pred == predFromProb) } - // force it to use probability2prediction + // force it to use raw2prediction model.setProbabilityCol("") + val resultsUsingRaw2Predict = + model.transform(smallBinaryDataset).select("prediction").as[Double].collect() + resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + model.setRawPredictionCol("") val resultsUsingProb2Predict = model.transform(smallBinaryDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { -- GitLab