From 248916f5589155c0c3e93c3874781f17b08d598d Mon Sep 17 00:00:00 2001 From: Sean Owen <sowen@cloudera.com> Date: Sat, 24 Sep 2016 08:15:55 +0100 Subject: [PATCH] [SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most one 0 ## What changes were proposed in this pull request? Match ProbabilisticClassifer.thresholds requirements to R randomForest cutoff, requiring all > 0 ## How was this patch tested? Jenkins tests plus new test cases Author: Sean Owen <sowen@cloudera.com> Closes #15149 from srowen/SPARK-17057. --- .../classification/LogisticRegression.scala | 5 +-- .../ProbabilisticClassifier.scala | 20 +++++------ .../ml/param/shared/SharedParamsCodeGen.scala | 8 +++-- .../spark/ml/param/shared/sharedParams.scala | 4 +-- .../ProbabilisticClassifierSuite.scala | 35 +++++++++++++++---- .../ml/param/_shared_params_code_gen.py | 5 +-- python/pyspark/ml/param/shared.py | 4 +-- 7 files changed, 52 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 343d50c790..5ab63d1de9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * Set thresholds in multiclass (or binary) classification to adjust the probability of - * predicting each class. Array must have length equal to the number of classes, with values >= 0. + * predicting each class. Array must have length equal to the number of classes, with values > 0, + * excepting that at most one value may be 0. * The class with largest value p/t is predicted, where p is the original probability of that - * class and t is the class' threshold. + * class and t is the class's threshold. * * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared. * If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 1b6e77542c..e89da6ff8b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[ if (!isDefined(thresholds)) { probability.argmax } else { - val thresholds: Array[Double] = getThresholds - val probabilities = probability.toArray + val thresholds = getThresholds var argMax = 0 var max = Double.NegativeInfinity var i = 0 val probabilitySize = probability.size while (i < probabilitySize) { - if (thresholds(i) == 0.0) { - max = Double.PositiveInfinity + // Thresholds are all > 0, excepting that at most one may be 0. + // The single class whose threshold is 0, if any, will always be predicted + // ('scaled' = +Infinity). However in the case that this class also has + // 0 probability, the class will not be selected ('scaled' is NaN). + val scaled = probability(i) / thresholds(i) + if (scaled > max) { + max = scaled argMax = i - } else { - val scaled = probabilities(i) / thresholds(i) - if (scaled > max) { - max = scaled - argMax = i - } } i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 480b03d0f3..c94b8b4e9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + + " Array must have length equal to the number of classes, with values > 0" + + " excepting that at most one value may be 0." + " The class with largest value p/t is predicted, where p is the original probability" + - " of that class and t is the class' threshold", - isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false), + " of that class and t is the class's threshold", + isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1", + finalMethods = false), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 9125d9e19b..fa4530927e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params { private[ml] trait HasThresholds extends Params { /** - * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. * @group param */ - final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0)) + final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1) /** @group getParam */ def getThresholds: Array[Double] = $(thresholds) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index b3bd2b3e57..172c64aab9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel( rawPrediction } - def friendlyPredict(input: Vector): Double = { - predict(input) + def friendlyPredict(values: Double*): Double = { + predict(Vectors.dense(values.toArray)) } } @@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel( class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { - val thresholds = Array(0.5, 0.2) val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - .setThresholds(thresholds) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + .setThresholds(Array(0.5, 0.2)) + assert(testModel.friendlyPredict(1.0, 1.0) === 1.0) + assert(testModel.friendlyPredict(1.0, 0.2) === 0.0) } test("test thresholding not required") { val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + assert(testModel.friendlyPredict(1.0, 2.0) === 1.0) + } + + test("test tiebreak") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(Array(0.4, 0.4)) + assert(testModel.friendlyPredict(0.6, 0.6) === 0.0) + } + + test("test one zero threshold") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(Array(0.0, 0.1)) + assert(testModel.friendlyPredict(1.0, 10.0) === 0.0) + assert(testModel.friendlyPredict(0.0, 10.0) === 1.0) + } + + test("bad thresholds") { + intercept[IllegalArgumentException] { + new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0)) + } + intercept[IllegalArgumentException] { + new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) + } } } diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 4f4328bcad..929591236d 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -139,8 +139,9 @@ if __name__ == "__main__": "model.", "True", "TypeConverters.toBoolean"), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + - "values >= 0. The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None, + "values > 0, excepting that at most one value may be 0. " + + "The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class's threshold.", None, "TypeConverters.toListFloat"), ("weightCol", "weight column name. If this is not set or empty, we treat " + "all instance weights as 1.0.", None, "TypeConverters.toString"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 24af07afc7..cc596936d8 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -469,10 +469,10 @@ class HasStandardization(Params): class HasThresholds(Params): """ - Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. """ - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat) + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat) def __init__(self): super(HasThresholds, self).__init__() -- GitLab