diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9c4c59a5e60fa54bc960dd197d9d8e89d0edb6e8..f8bcbeedfb042e5cd653551938e21b237917b31a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -359,8 +359,16 @@ class LogisticRegressionSuite assert(pred == predFromProb) } - // force it to use probability2prediction + // force it to use raw2prediction model.setProbabilityCol("") + val resultsUsingRaw2Predict = + model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() + resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + model.setRawPredictionCol("") val resultsUsingProb2Predict = model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { @@ -405,8 +413,16 @@ class LogisticRegressionSuite assert(pred == predFromProb) } - // force it to use probability2prediction + // force it to use raw2prediction model.setProbabilityCol("") + val resultsUsingRaw2Predict = + model.transform(smallBinaryDataset).select("prediction").as[Double].collect() + resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + model.setRawPredictionCol("") val resultsUsingProb2Predict = model.transform(smallBinaryDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {