diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala index a6ececbeb67f16a5806d1341ca052649df5a4706..8343f281390ac85934e2304f4626e5a76c3c4710 100644 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala @@ -1,12 +1,13 @@ package spark.mllib.regression import spark.{Logging, RDD, SparkContext} -import spark.SparkContext._ import spark.mllib.util.MLUtils import org.jblas.DoubleMatrix import org.jblas.Solve +import scala.annotation.tailrec + /** * Ridge Regression from Joseph Gonzalez's implementation in MLBase */ @@ -99,20 +100,28 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) // Binary search for the best assignment to lambda. def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { - val mid = (high - low) / 2 + low - val lowValue = crossValidate((mid - low) / 2 + low) - val highValue = crossValidate((high - mid) / 2 + mid) - val (newLow, newHigh) = if (lowValue._2 < highValue._2) { - (low, mid + (high-low)/4) - } else { - (mid - (high-low)/4, high) - } - if (newHigh - newLow > 1.0E-7) { - // :: is list prepend in Scala. - lowValue :: highValue :: binSearch(newLow, newHigh) - } else { - List(lowValue, highValue) + @tailrec + def loop( + low: Double, + high: Double, + acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = { + val mid = (high - low) / 2 + low + val lowValue = crossValidate((mid - low) / 2 + low) + val highValue = crossValidate((high - mid) / 2 + mid) + val (newLow, newHigh) = if (lowValue._2 < highValue._2) { + (low, mid + (high-low)/4) + } else { + (mid - (high-low)/4, high) + } + if (newHigh - newLow > 1.0E-7) { + // :: is list prepend in Scala. + loop(newLow, newHigh, lowValue :: highValue :: acc) + } else { + lowValue :: highValue :: acc + } } + + loop(low, high, Nil) } // Actually compute the best lambda