diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 343d50c790e85ff9ef42045c390ae4cae6c52269..5ab63d1de95d346b3ecc0b7f817fe106826f1d5e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
 
   /**
    * Set thresholds in multiclass (or binary) classification to adjust the probability of
-   * predicting each class. Array must have length equal to the number of classes, with values >= 0.
+   * predicting each class. Array must have length equal to the number of classes, with values > 0,
+   * excepting that at most one value may be 0.
    * The class with largest value p/t is predicted, where p is the original probability of that
-   * class and t is the class' threshold.
+   * class and t is the class's threshold.
    *
    * Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
    *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1b6e77542cc8056fff1a32d40a6fe17cdb3f899d..e89da6ff8bdd72b071611aedfdc7cb1f339dca8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[
     if (!isDefined(thresholds)) {
       probability.argmax
     } else {
-      val thresholds: Array[Double] = getThresholds
-      val probabilities = probability.toArray
+      val thresholds = getThresholds
       var argMax = 0
       var max = Double.NegativeInfinity
       var i = 0
       val probabilitySize = probability.size
       while (i < probabilitySize) {
-        if (thresholds(i) == 0.0) {
-          max = Double.PositiveInfinity
+        // Thresholds are all > 0, excepting that at most one may be 0.
+        // The single class whose threshold is 0, if any, will always be predicted
+        // ('scaled' = +Infinity). However in the case that this class also has
+        // 0 probability, the class will not be selected ('scaled' is NaN).
+        val scaled = probability(i) / thresholds(i)
+        if (scaled > max) {
+          max = scaled
           argMax = i
-        } else {
-          val scaled = probabilities(i) / thresholds(i)
-          if (scaled > max) {
-            max = scaled
-            argMax = i
-          }
         }
         i += 1
       }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 480b03d0f35c45a600fc36513e8df29d470972a0..c94b8b4e9dfda35347ad2564404d60402edceb8f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen {
         isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
       ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
         " to adjust the probability of predicting each class." +
-        " Array must have length equal to the number of classes, with values >= 0." +
+        " Array must have length equal to the number of classes, with values > 0" +
+        " excepting that at most one value may be 0." +
         " The class with largest value p/t is predicted, where p is the original probability" +
-        " of that class and t is the class' threshold",
-        isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
+        " of that class and t is the class's threshold",
+        isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1",
+        finalMethods = false),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 9125d9e19bf0961aeef5a059c93e859f1f81fce9..fa4530927e8b04c5154372a6fa96e3896d629077 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params {
 private[ml] trait HasThresholds extends Params {
 
   /**
-   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
    * @group param
    */
-  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1)
 
   /** @group getParam */
   def getThresholds: Array[Double] = $(thresholds)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index b3bd2b3e57b36ba51da2af62b40d812a32061b69..172c64aab9d3dbd7aa83784e74fb33ef4be26883 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel(
     rawPrediction
   }
 
-  def friendlyPredict(input: Vector): Double = {
-    predict(input)
+  def friendlyPredict(values: Double*): Double = {
+    predict(Vectors.dense(values.toArray))
   }
 }
 
@@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel(
 class ProbabilisticClassifierSuite extends SparkFunSuite {
 
   test("test thresholding") {
-    val thresholds = Array(0.5, 0.2)
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-      .setThresholds(thresholds)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+      .setThresholds(Array(0.5, 0.2))
+    assert(testModel.friendlyPredict(1.0, 1.0) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 0.2) === 0.0)
   }
 
   test("test thresholding not required") {
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 2.0) === 1.0)
+  }
+
+  test("test tiebreak") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.4, 0.4))
+    assert(testModel.friendlyPredict(0.6, 0.6) === 0.0)
+  }
+
+  test("test one zero threshold") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.0, 0.1))
+    assert(testModel.friendlyPredict(1.0, 10.0) === 0.0)
+    assert(testModel.friendlyPredict(0.0, 10.0) === 1.0)
+  }
+
+  test("bad thresholds") {
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(0.0, 0.0))
+    }
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1))
+    }
   }
 }
 
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 4f4328bcadc6f75b315120d7895a12445e5ecda2..929591236d688130f6f97d88004c779c309dff23 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -139,8 +139,9 @@ if __name__ == "__main__":
          "model.", "True", "TypeConverters.toBoolean"),
         ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
          "predicting each class. Array must have length equal to the number of classes, with " +
-         "values >= 0. The class with largest value p/t is predicted, where p is the original " +
-         "probability of that class and t is the class' threshold.", None,
+         "values > 0, excepting that at most one value may be 0. " +
+         "The class with largest value p/t is predicted, where p is the original " +
+         "probability of that class and t is the class's threshold.", None,
          "TypeConverters.toListFloat"),
         ("weightCol", "weight column name. If this is not set or empty, we treat " +
          "all instance weights as 1.0.", None, "TypeConverters.toString"),
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 24af07afc7d5c521d952da14080bcf32250e25cc..cc596936d82f6197b9ab64f1175de2a549973dca 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -469,10 +469,10 @@ class HasStandardization(Params):
 
 class HasThresholds(Params):
     """
-    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+    Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.
     """
 
-    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", typeConverter=TypeConverters.toListFloat)
+    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold.", typeConverter=TypeConverters.toListFloat)
 
     def __init__(self):
         super(HasThresholds, self).__init__()