From 07f17439a52b65d4f5ef8c8d80bc25dadc0182a8 Mon Sep 17 00:00:00 2001
From: Xinghao <pxinghao@gmail.com>
Date: Mon, 29 Jul 2013 09:22:31 -0700
Subject: [PATCH] Fix validatePrediction functions for Classification models

Classifiers return categorical (Int) labels, so predictions should be
compared to expected values with an exact equality check rather than a
numeric distance threshold.
---
 .../spark/mllib/classification/LogisticRegressionSuite.scala   | 3 +--
 mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
index 3aa9fe6d12..d3fe58a382 100644
--- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -62,8 +62,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
 
   def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
     val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
-      // A prediction is off if the prediction is more than 0.5 away from expected value.
-      math.abs(prediction.toDouble - expected.toDouble) > 0.5
+      (prediction != expected)
     }.size
     // At least 80% of the predictions should be on.
     assert(numOffPredictions < input.length / 5)
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 3f00398a0a..d546e0729e 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -52,8 +52,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
 
   def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
     val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
-      // A prediction is off if the prediction is more than 0.5 away from expected value.
-      math.abs(prediction - expected) > 0.5
+      (prediction != expected)
     }.size
     // At least 80% of the predictions should be on.
     assert(numOffPredictions < input.length / 5)
-- 
GitLab