diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala index 8343f281390ac85934e2304f4626e5a76c3c4710..36cda721ddbe268e909ca587d8b3ebafc5c353ed 100644 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala @@ -15,7 +15,7 @@ class RidgeRegressionModel( val weights: DoubleMatrix, val intercept: Double, val lambdaOpt: Double, - val lambdas: List[(Double, Double, DoubleMatrix)]) + val lambdas: Seq[(Double, Double, DoubleMatrix)]) extends RegressionModel { override def predict(testData: RDD[Array[Double]]): RDD[Double] = { @@ -99,12 +99,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) } // Binary search for the best assignment to lambda. - def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { + def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = { @tailrec - def loop( - low: Double, - high: Double, - acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = { + def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)]) + : Seq[(Double, Double, DoubleMatrix)] = { val mid = (high - low) / 2 + low val lowValue = crossValidate((mid - low) / 2 + low) val highValue = crossValidate((high - mid) / 2 + mid) @@ -114,14 +112,13 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) (mid - (high-low)/4, high) } if (newHigh - newLow > 1.0E-7) { - // :: is list prepend in Scala. - loop(newLow, newHigh, lowValue :: highValue :: acc) + loop(newLow, newHigh, acc :+ lowValue :+ highValue) } else { - lowValue :: highValue :: acc + acc :+ lowValue :+ highValue } } - loop(low, high, Nil) + loop(low, high, Vector.empty) } // Actually compute the best lambda @@ -143,6 +140,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) model } } + /** * Top-level methods for calling Ridge Regression. */