Skip to content
Snippets Groups Projects
Commit 75f37573 authored by Xinghao's avatar Xinghao
Browse files

Fix rounding error in LogisticRegression.scala

parent c823ee1e
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,8 @@ import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.util.MLUtils
import scala.math.round
import org.jblas.DoubleMatrix
/**
......@@ -42,14 +44,14 @@ class LogisticRegressionModel(
val localIntercept = intercept
testData.map { x =>
val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
(1.0/ (1.0 + math.exp(margin * -1))).toInt
round(1.0/ (1.0 + math.exp(margin * -1))).toInt
}
}
override def predict(testData: Array[Double]): Int = {
val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
(1.0/ (1.0 + math.exp(margin * -1))).toInt
round(1.0/ (1.0 + math.exp(margin * -1))).toInt
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment