diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index cffa1ab700f806d745ad7196fd978f14829ef3a3..ab54cb06d5aab2134b9d88bc9d4f2b70f1ba0bef 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
 import breeze.stats.distributions.{Multinomial => BrzMultinomial}
+import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -103,17 +104,24 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
       piData: Array[Double],
       thetaData: Array[Array[Double]],
       model: NaiveBayesModel): Unit = {
-    def closeFit(d1: Double, d2: Double, precision: Double): Boolean = {
-      (d1 - d2).abs <= precision
-    }
-    val modelIndex = (0 until piData.length).zip(model.labels.map(_.toInt))
-    for (i <- modelIndex) {
-      assert(closeFit(math.exp(piData(i._2)), math.exp(model.pi(i._1)), 0.05))
-    }
-    for (i <- modelIndex) {
-      for (j <- 0 until thetaData(i._2).length) {
-        assert(closeFit(math.exp(thetaData(i._2)(j)), math.exp(model.theta(i._1)(j)), 0.05))
+    val modelIndex = piData.indices.zip(model.labels.map(_.toInt))
+    try {
+      for (i <- modelIndex) {
+        assert(math.exp(piData(i._2)) ~== math.exp(model.pi(i._1)) absTol 0.05)
+        for (j <- thetaData(i._2).indices) {
+          assert(math.exp(thetaData(i._2)(j)) ~== math.exp(model.theta(i._1)(j)) absTol 0.05)
+        }
       }
+    } catch {
+      case e: TestFailedException =>
+        def arr2str(a: Array[Double]): String = a.mkString("[", ", ", "]")
+        def msg(orig: String): String = orig + "\nvalidateModelFit:\n" +
+          " piData: " + arr2str(piData) + "\n" +
+          " thetaData: " + thetaData.map(arr2str).mkString("\n") + "\n" +
+          " model.labels: " + arr2str(model.labels) + "\n" +
+          " model.pi: " + arr2str(model.pi) + "\n" +
+          " model.theta: " + model.theta.map(arr2str).mkString("\n")
+        throw e.modifyMessage(_.map(msg))
     }
   }