Skip to content
Snippets Groups Projects
Commit 757e56df authored by ryanlecompte's avatar ryanlecompte
Browse files

make binSearch a tail-recursive method

parent bf1311e6
No related branches found
No related tags found
No related merge requests found
package spark.mllib.regression package spark.mllib.regression
import spark.{Logging, RDD, SparkContext} import spark.{Logging, RDD, SparkContext}
import spark.SparkContext._
import spark.mllib.util.MLUtils import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix import org.jblas.DoubleMatrix
import org.jblas.Solve import org.jblas.Solve
import scala.annotation.tailrec
/** /**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase * Ridge Regression from Joseph Gonzalez's implementation in MLBase
*/ */
...@@ -99,20 +100,28 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) ...@@ -99,20 +100,28 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
// Binary search for the best assignment to lambda. // Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low @tailrec
val lowValue = crossValidate((mid - low) / 2 + low) def loop(
val highValue = crossValidate((high - mid) / 2 + mid) low: Double,
val (newLow, newHigh) = if (lowValue._2 < highValue._2) { high: Double,
(low, mid + (high-low)/4) acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = {
} else { val mid = (high - low) / 2 + low
(mid - (high-low)/4, high) val lowValue = crossValidate((mid - low) / 2 + low)
} val highValue = crossValidate((high - mid) / 2 + mid)
if (newHigh - newLow > 1.0E-7) { val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
// :: is list prepend in Scala. (low, mid + (high-low)/4)
lowValue :: highValue :: binSearch(newLow, newHigh) } else {
} else { (mid - (high-low)/4, high)
List(lowValue, highValue) }
if (newHigh - newLow > 1.0E-7) {
// :: is list prepend in Scala.
loop(newLow, newHigh, lowValue :: highValue :: acc)
} else {
lowValue :: highValue :: acc
}
} }
loop(low, high, Nil)
} }
// Actually compute the best lambda // Actually compute the best lambda
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment