From 757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc Mon Sep 17 00:00:00 2001 From: ryanlecompte <lecompte@gmail.com> Date: Fri, 5 Jul 2013 19:54:28 -0700 Subject: [PATCH] make binSearch a tail-recursive method --- .../mllib/regression/RidgeRegression.scala | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala index a6ececbeb6..8343f28139 100644 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala @@ -1,12 +1,13 @@ package spark.mllib.regression import spark.{Logging, RDD, SparkContext} -import spark.SparkContext._ import spark.mllib.util.MLUtils import org.jblas.DoubleMatrix import org.jblas.Solve +import scala.annotation.tailrec + /** * Ridge Regression from Joseph Gonzalez's implementation in MLBase */ @@ -99,20 +100,28 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) // Binary search for the best assignment to lambda. def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { - val mid = (high - low) / 2 + low - val lowValue = crossValidate((mid - low) / 2 + low) - val highValue = crossValidate((high - mid) / 2 + mid) - val (newLow, newHigh) = if (lowValue._2 < highValue._2) { - (low, mid + (high-low)/4) - } else { - (mid - (high-low)/4, high) - } - if (newHigh - newLow > 1.0E-7) { - // :: is list prepend in Scala. - lowValue :: highValue :: binSearch(newLow, newHigh) - } else { - List(lowValue, highValue) + @tailrec + def loop( + low: Double, + high: Double, + acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = { + val mid = (high - low) / 2 + low + val lowValue = crossValidate((mid - low) / 2 + low) + val highValue = crossValidate((high - mid) / 2 + mid) + val (newLow, newHigh) = if (lowValue._2 < highValue._2) { + (low, mid + (high-low)/4) + } else { + (mid - (high-low)/4, high) + } + if (newHigh - newLow > 1.0E-7) { + // :: is list prepend in Scala. + loop(newLow, newHigh, lowValue :: highValue :: acc) + } else { + lowValue :: highValue :: acc + } } + + loop(low, high, Nil) } // Actually compute the best lambda -- GitLab