diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala index a6ececbeb67f16a5806d1341ca052649df5a4706..f66025bc0bb97f521a79009df4fba9f0734b9ab1 100644 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala @@ -1,12 +1,14 @@ package spark.mllib.regression import spark.{Logging, RDD, SparkContext} -import spark.SparkContext._ import spark.mllib.util.MLUtils import org.jblas.DoubleMatrix import org.jblas.Solve +import scala.annotation.tailrec +import scala.collection.mutable + /** * Ridge Regression from Joseph Gonzalez's implementation in MLBase */ @@ -14,7 +16,7 @@ class RidgeRegressionModel( val weights: DoubleMatrix, val intercept: Double, val lambdaOpt: Double, - val lambdas: List[(Double, Double, DoubleMatrix)]) + val lambdas: Seq[(Double, Double, DoubleMatrix)]) extends RegressionModel { override def predict(testData: RDD[Array[Double]]): RDD[Double] = { @@ -98,21 +100,29 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) } // Binary search for the best assignment to lambda. - def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { - val mid = (high - low) / 2 + low - val lowValue = crossValidate((mid - low) / 2 + low) - val highValue = crossValidate((high - mid) / 2 + mid) - val (newLow, newHigh) = if (lowValue._2 < highValue._2) { - (low, mid + (high-low)/4) - } else { - (mid - (high-low)/4, high) - } - if (newHigh - newLow > 1.0E-7) { - // :: is list prepend in Scala. - lowValue :: highValue :: binSearch(newLow, newHigh) - } else { - List(lowValue, highValue) + def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = { + val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)] + + @tailrec + def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = { + val mid = (high - low) / 2 + low + val lowValue = crossValidate((mid - low) / 2 + low) + val highValue = crossValidate((high - mid) / 2 + mid) + val (newLow, newHigh) = if (lowValue._2 < highValue._2) { + (low, mid + (high-low)/4) + } else { + (mid - (high-low)/4, high) + } + if (newHigh - newLow > 1.0E-7) { + buffer += lowValue += highValue + loop(newLow, newHigh) + } else { + buffer += lowValue += highValue + buffer.result() + } } + + loop(low, high) } // Actually compute the best lambda @@ -134,6 +144,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) model } } + /** * Top-level methods for calling Ridge Regression. */