Skip to content
Snippets Groups Projects
Commit 744da8ee authored by shivaram's avatar shivaram
Browse files

Merge pull request #679 from ryanlecompte/master

Make binSearch method tail-recursive for RidgeRegression
parents 3cc6818f be123aa6
No related branches found
No related tags found
No related merge requests found
package spark.mllib.regression
import spark.{Logging, RDD, SparkContext}
import spark.SparkContext._
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
import org.jblas.Solve
import scala.annotation.tailrec
import scala.collection.mutable
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
*/
......@@ -14,7 +16,7 @@ class RidgeRegressionModel(
val weights: DoubleMatrix,
val intercept: Double,
val lambdaOpt: Double,
val lambdas: List[(Double, Double, DoubleMatrix)])
val lambdas: Seq[(Double, Double, DoubleMatrix)])
extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
......@@ -98,21 +100,29 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
}
// Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
(low, mid + (high-low)/4)
} else {
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
// :: is list prepend in Scala.
lowValue :: highValue :: binSearch(newLow, newHigh)
} else {
List(lowValue, highValue)
/**
 * Binary search for the best assignment to lambda over the interval [low, high].
 *
 * Each round cross-validates two probe points (the quarter and three-quarter
 * positions of the current interval), records both results, and then keeps the
 * three-quarter-wide sub-interval on the side whose probe has the smaller second
 * tuple component (presumably the validation error -- confirm against
 * crossValidate). The search stops once the next interval would be no wider
 * than 1.0E-7.
 *
 * @param low  lower bound of the lambda interval to search
 * @param high upper bound of the lambda interval to search
 * @return every result produced by crossValidate, in evaluation order
 */
def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
  // Accumulates results across iterations so the helper below can stay tail-recursive.
  val evaluated = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]

  @tailrec
  def narrow(lo: Double, hi: Double): Seq[(Double, Double, DoubleMatrix)] = {
    val center = (hi - lo) / 2 + lo
    val lowerResult = crossValidate((center - lo) / 2 + lo)
    val upperResult = crossValidate((hi - center) / 2 + center)
    // Both probes are recorded whether or not the search continues.
    evaluated += lowerResult += upperResult

    val quarter = (hi - lo) / 4
    val (nextLo, nextHi) =
      if (lowerResult._2 < upperResult._2) (lo, center + quarter)
      else (center - quarter, hi)

    if (nextHi - nextLo > 1.0E-7) narrow(nextLo, nextHi)
    else evaluated.result()
  }

  narrow(low, high)
}
// Actually compute the best lambda
......@@ -134,6 +144,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
model
}
}
/**
* Top-level methods for calling Ridge Regression.
*/
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment