diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 918ab27e2730bbffd3d0a94d24f7e76f3f293a81..98c879ece62d624d874c01ec72fa26932909f406 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite
       assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
         "probability prediction mismatch")
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, DecisionTreeClassificationModel](newTree, newData)
   }
 
   test("training with 1-category categorical feature") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1f79e0d4e6228a311b5c955fa5456bc606725cbc..8000143d4d142c3940ddbeee8681c3bfe2bd1210 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
       case (pred1, pred2) => assert(pred1 === pred2)
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, GBTClassificationModel](gbtModel, validationDataset)
   }
 
   test("GBT parameter stepSize should be in interval (0, 1]") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 6bf1253b718d161145f5229b9b277f2944ac56b9..d43c7cdbde62cae072a29e2e327c3762b5767b44 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -502,6 +502,9 @@ class LogisticRegressionSuite
     resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
       case (pred1, pred2) => assert(pred1 === pred2)
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, LogisticRegressionModel](model, smallMultinomialDataset)
   }
 
   test("binary logistic regression: Predictor, Classifier methods") {
@@ -556,6 +559,9 @@ class LogisticRegressionSuite
     resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
       case (pred1, pred2) => assert(pred1 === pred2)
     }
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, LogisticRegressionModel](model, smallBinaryDataset)
   }
 
   test("coefficients and intercept methods") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index c294e4ad54bf740700f93269c04cad7dc0c6e275..d3141ec7085604b27f02bee81f4d76fb873a7e05 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -104,6 +104,8 @@ class MultilayerPerceptronClassifierSuite
       case Row(p: Vector, e: Vector) =>
         assert(p ~== e absTol 1e-3)
     }
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
   }
 
   test("test model probability") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 3a2be236f1257a59cf5583f65a54f71836df91ab..9730dd68a3b27c3ee20f2557e44e85ab6ab40971 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val featureAndProbabilities = model.transform(validationDataset)
       .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "multinomial")
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, NaiveBayesModel](model, testDataset)
   }
 
   test("Naive Bayes with weighted samples") {
@@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val featureAndProbabilities = model.transform(validationDataset)
      .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "bernoulli")
+
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, NaiveBayesModel](model, testDataset)
   }
 
   test("detect negative values") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 172c64aab9d3dbd7aa83784e74fb33ef4be26883..4ecd5a05365eb98ff7c0f72762e48c7fd3f7734a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
 
 final class TestProbabilisticClassificationModel(
     override val uid: String,
@@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
     "thresholds" -> Array(0.4, 0.6)
   )
 
+  /**
+   * Helper for testing that a ProbabilisticClassificationModel computes
+   * the same predictions across all combinations of output columns
+   * (rawPrediction/probability/prediction) turned on/off. Makes sure the
+   * output column values match by comparing vs. the case with all 3 output
+   * columns turned on.
+   */
+  def testPredictMethods[
+      FeaturesType,
+      M <: ProbabilisticClassificationModel[FeaturesType, M]](
+    model: M, testData: Dataset[_]): Unit = {
+
+    val allColModel = model.copy(ParamMap.empty)
+      .setRawPredictionCol("rawPredictionAll")
+      .setProbabilityCol("probabilityAll")
+      .setPredictionCol("predictionAll")
+    val allColResult = allColModel.transform(testData)
+
+    for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
+      for (probabilityCol <- Seq("", "probabilitySingle")) {
+        for (predictionCol <- Seq("", "predictionSingle")) {
+          val newModel = model.copy(ParamMap.empty)
+            .setRawPredictionCol(rawPredictionCol)
+            .setProbabilityCol(probabilityCol)
+            .setPredictionCol(predictionCol)
+
+          val result = newModel.transform(allColResult)
+
+          import org.apache.spark.sql.functions._
+
+          val resultRawPredictionCol =
+            if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
+          val resultProbabilityCol =
+            if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
+          val resultPredictionCol =
+            if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)
+
+          result.select(
+            resultRawPredictionCol, col("rawPredictionAll"),
+            resultProbabilityCol, col("probabilityAll"),
+            resultPredictionCol, col("predictionAll")
+          ).collect().foreach {
+            case Row(
+              rawPredictionSingle: Vector, rawPredictionAll: Vector,
+              probabilitySingle: Vector, probabilityAll: Vector,
+              predictionSingle: Double, predictionAll: Double
+            ) => {
+              assert(rawPredictionSingle ~== rawPredictionAll relTol 1E-3)
+              assert(probabilitySingle ~== probabilityAll relTol 1E-3)
+              assert(predictionSingle ~== predictionAll relTol 1E-3)
+            }
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ca2954d2f32c4018cff3b45faabdeab06ab3cccf..2cca2e6c046983c919f77d7a60848bcdcb304fb1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,8 @@ class RandomForestClassifierSuite
         "probability prediction mismatch")
       assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
     }
+    ProbabilisticClassifierSuite.testPredictMethods[
+      Vector, RandomForestClassificationModel](model, df)
   }
 
   test("Fitting without numClasses in metadata") {
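Not part of the patch: a minimal sketch of how a suite in the same test package could exercise the new helper. It assumes the usual SparkFunSuite / MLlibTestSparkContext fixtures (which provide `testImplicits`); the test name, toy dataset, and estimator settings are illustrative only.

  import org.apache.spark.ml.linalg.{Vector, Vectors}

  test("predict methods agree across output-column settings (sketch)") {
    import testImplicits._
    // Hypothetical tiny training set; any DataFrame with label/features columns works.
    val df = Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (1.0, Vectors.dense(1.5, 0.5)),
      (0.0, Vectors.dense(0.5, 1.5))
    ).toDF("label", "features")
    val model = new LogisticRegression().setMaxIter(5).fit(df)
    // Checks every on/off combination of rawPrediction/probability/prediction
    // columns against the run with all three columns enabled.
    ProbabilisticClassifierSuite.testPredictMethods[
      Vector, LogisticRegressionModel](model, df)
  }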