Skip to content
Snippets Groups Projects
Commit 757e56df authored by ryanlecompte's avatar ryanlecompte
Browse files

make binSearch a tail-recursive method

parent bf1311e6
No related branches found
No related tags found
No related merge requests found
package spark.mllib.regression
import spark.{Logging, RDD, SparkContext}
import spark.SparkContext._
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
import org.jblas.Solve
import scala.annotation.tailrec
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
*/
......@@ -99,20 +100,28 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
// Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
(low, mid + (high-low)/4)
} else {
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
// :: is list prepend in Scala.
lowValue :: highValue :: binSearch(newLow, newHigh)
} else {
List(lowValue, highValue)
@tailrec
def loop(
low: Double,
high: Double,
acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
(low, mid + (high-low)/4)
} else {
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
// :: is list prepend in Scala.
loop(newLow, newHigh, lowValue :: highValue :: acc)
} else {
lowValue :: highValue :: acc
}
}
loop(low, high, Nil)
}
// Actually compute the best lambda
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment