Skip to content
Snippets Groups Projects
Commit f78f8d0b authored by ryanlecompte's avatar ryanlecompte
Browse files

fix formatting and use Vector instead of List to maintain order

parent 757e56df
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ class RidgeRegressionModel( ...@@ -15,7 +15,7 @@ class RidgeRegressionModel(
val weights: DoubleMatrix, val weights: DoubleMatrix,
val intercept: Double, val intercept: Double,
val lambdaOpt: Double, val lambdaOpt: Double,
val lambdas: List[(Double, Double, DoubleMatrix)]) val lambdas: Seq[(Double, Double, DoubleMatrix)])
extends RegressionModel { extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = { override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
...@@ -99,12 +99,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) ...@@ -99,12 +99,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
} }
// Binary search for the best assignment to lambda. // Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
@tailrec @tailrec
def loop( def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)])
low: Double, : Seq[(Double, Double, DoubleMatrix)] = {
high: Double,
acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low) val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid) val highValue = crossValidate((high - mid) / 2 + mid)
...@@ -114,14 +112,13 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) ...@@ -114,14 +112,13 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
(mid - (high-low)/4, high) (mid - (high-low)/4, high)
} }
if (newHigh - newLow > 1.0E-7) { if (newHigh - newLow > 1.0E-7) {
// :: is list prepend in Scala. loop(newLow, newHigh, acc :+ lowValue :+ highValue)
loop(newLow, newHigh, lowValue :: highValue :: acc)
} else { } else {
lowValue :: highValue :: acc acc :+ lowValue :+ highValue
} }
} }
loop(low, high, Nil) loop(low, high, Vector.empty)
} }
// Actually compute the best lambda // Actually compute the best lambda
...@@ -143,6 +140,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) ...@@ -143,6 +140,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
model model
} }
} }
/** /**
* Top-level methods for calling Ridge Regression. * Top-level methods for calling Ridge Regression.
*/ */
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment