package spark.mllib.regression

import spark.{Logging, RDD, SparkContext}
import spark.mllib.util.MLUtils

import org.jblas.DoubleMatrix
import org.jblas.Solve

import scala.annotation.tailrec
import scala.collection.mutable

/**
 * Ridge Regression from Joseph Gonzalez's implementation in MLBase
 */
class RidgeRegressionModel(
    val weights: DoubleMatrix,
    val intercept: Double,
    val lambdaOpt: Double,
    val lambdas: Seq[(Double, Double, DoubleMatrix)])
  extends RegressionModel {

  override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
    testData.map { x =>
      (new DoubleMatrix(1, x.length, x: _*).mmul(this.weights)).get(0) + this.intercept
    }
  }

  override def predict(testData: Array[Double]): Double = {
    (new DoubleMatrix(1, testData.length, testData: _*).mmul(this.weights)).get(0) + this.intercept
  }
}

class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
  extends Logging {

  def this() = this(0.0, 100.0)

  /**
   * Set the lower bound of the binary search for lambda. Default is 0.0.
   */
  def setLowLambda(low: Double): RidgeRegression = {
    this.lambdaLow = low
    this
  }

  /**
   * Set the upper bound of the binary search for lambda. Default is 100.0.
   */
  def setHighLambda(hi: Double): RidgeRegression = {
    this.lambdaHigh = hi
    this
  }

  def train(input: RDD[(Double, Array[Double])]): RidgeRegressionModel = {
    val nfeatures: Int = input.take(1)(0)._2.length
    val nexamples: Long = input.count()

    val (yMean, xColMean, xColSd) = MLUtils.computeStats(input, nfeatures, nexamples)

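    // Standardize the data: center the labels and scale every feature to zero
    // mean and unit variance, so that a single lambda penalizes all weights on
    // a comparable scale. The weights and intercept are mapped back to the
    // original scale once the model has been fit.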
    val data = input.map { case (y, features) =>
      val yNormalized = y - yMean
      val featuresMat = new DoubleMatrix(nfeatures, 1, features: _*)
      val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
      (yNormalized, featuresNormalized.toArray)
    }

    // Compute XtX - size of XtX is nfeatures by nfeatures
    val XtX: DoubleMatrix = data.map { case (y, features) =>
      val x = new DoubleMatrix(1, features.length, features: _*)
      x.transpose().mmul(x)
    }.reduce(_.addi(_))

    // Compute Xt*y - size of Xty is nfeatures by 1
    val Xty: DoubleMatrix = data.map { case (y, features) =>
      new DoubleMatrix(features.length, 1, features: _*).mul(y)
    }.reduce(_.addi(_))

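    // Both sums are accumulated distributively: XtX = sum_i x_i * x_i^t and
    // Xty = sum_i y_i * x_i, so only nfeatures-by-nfeatures matrices (and
    // nfeatures-by-1 vectors) flow through the reduce, never the full data set.
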
    // Define a function to compute the leave-one-out cross-validation error
    // for a single value of lambda.
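    //
    // For ridge regression the leave-one-out residual of example i has a
    // closed form that avoids refitting. With the hat matrix
    //   H = X * inv(XtX + lambda*I) * Xt,
    // the LOO residual is (y_i - yhat_i) / (1 - H_ii), and the CV error is
    // the mean of its square over all examples.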
    def crossValidate(lambda: Double): (Double, Double, DoubleMatrix) = {
      // Compute the ridge regression parameter vector:
      //   w = inv(XtX + lambda*I) * Xty
      val XtXlambda = DoubleMatrix.eye(nfeatures).muli(lambda).addi(XtX)
      val w = Solve.solveSymmetric(XtXlambda, Xty)
      val invXtX = Solve.solveSymmetric(XtXlambda, DoubleMatrix.eye(nfeatures))

      // Compute the leave-one-out cross-validation error using the diagonal
      // entries H_ii of the hat matrix.
      val cverror = data.map {
        case (y, features) =>
          val x = new DoubleMatrix(features.length, 1, features: _*)
          val yhat = w.transpose().mmul(x).get(0)
          val H_ii = x.transpose().mmul(invXtX).mmul(x).get(0)
          val residual = (y - yhat) / (1.0 - H_ii)
          residual * residual
      }.reduce(_ + _) / nexamples

      (lambda, cverror, w)
    }

    // Binary search for the lambda that minimizes the cross-validation error.
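    // The search assumes the CV error is roughly unimodal in lambda: each step
    // evaluates the error at the 1/4 and 3/4 points of the current interval,
    // keeps the sub-interval around the smaller value (plus a quarter-interval
    // margin), and stops once the interval is narrower than 1e-7. Every probe
    // is one call to crossValidate, i.e. one pass over the data.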
    def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
      val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]
      @tailrec
      def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
        val mid = (high - low) / 2 + low
        val lowValue = crossValidate((mid - low) / 2 + low)
        val highValue = crossValidate((high - mid) / 2 + mid)
        val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
          (low, mid + (high - low) / 4)
        } else {
          (mid - (high - low) / 4, high)
        }
        if (newHigh - newLow > 1.0E-7) {
          buffer += lowValue += highValue
          loop(newLow, newHigh)
        } else {
          buffer += lowValue += highValue
          buffer.result()
        }
      }
      loop(low, high)
    }

    // Evaluate lambdas between the bounds and sort the results by lambda.
    val lambdas = binSearch(lambdaLow, lambdaHigh).sortBy(_._1)

    // Find the best parameter set by taking the lowest cverror.
    val (lambdaOpt, cverror, weights) = lambdas.reduce((a, b) => if (a._2 < b._2) a else b)

    // Map the solution back to the original (unstandardized) feature scale.
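    // Since the model was fit on standardized data,
    //   yhat = w . ((x - xColMean) / xColSd) + yMean
    //        = (w / xColSd) . x + (yMean - w . (xColMean / xColSd)),
    // which gives the rescaled weights and the intercept below.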
    val weightsScaled = weights.div(xColSd)
    val intercept = yMean - (weights.transpose().mmul(xColMean.div(xColSd)).get(0))
    val model = new RidgeRegressionModel(weightsScaled, intercept, lambdaOpt, lambdas)

    logInfo("RidgeRegression: optimal lambda " + model.lambdaOpt)
    logInfo("RidgeRegression: optimal weights " + model.weights)
    logInfo("RidgeRegression: optimal intercept " + model.intercept)
    logInfo("RidgeRegression: cross-validation error " + cverror)

    model
  }
}

/**
 * Top-level methods for calling Ridge Regression.
 */
object RidgeRegression {

  /**
   * Train a ridge regression model given an RDD of (response, features) pairs.
   * We use the closed form solution to compute the cross-validation score for
   * a given lambda. The optimal lambda is computed by performing binary search
   * between the provided bounds of lambda.
   *
   * @param input RDD of (response, array of features) pairs.
   * @param lambdaLow lower bound used in binary search for lambda
   * @param lambdaHigh upper bound used in binary search for lambda
   */
  def train(
      input: RDD[(Double, Array[Double])],
      lambdaLow: Double,
      lambdaHigh: Double): RidgeRegressionModel = {
    new RidgeRegression(lambdaLow, lambdaHigh).train(input)
  }

  /**
   * Train a ridge regression model given an RDD of (response, features) pairs.
   * We use the closed form solution to compute the cross-validation score for
   * a given lambda. The optimal lambda is computed by performing binary search
   * between lambda values of 0 and 100.
   *
   * @param input RDD of (response, array of features) pairs.
   */
  def train(input: RDD[(Double, Array[Double])]): RidgeRegressionModel = {
    train(input, 0.0, 100.0)
  }

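  // Example usage (an illustrative sketch; assumes an already-constructed
  // SparkContext `sc` and an input path in the format expected by
  // MLUtils.loadData):
  //
  //   val data = MLUtils.loadData(sc, "path/to/data")
  //   val model = RidgeRegression.train(data, 0.0, 100.0)
  //   val prediction = model.predict(Array(1.0, 2.0, 3.0))
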
  def main(args: Array[String]) {
    if (args.length != 2) {
      println("Usage: RidgeRegression <master> <input_dir>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "RidgeRegression")
    val data = MLUtils.loadData(sc, args(1))
    val model = RidgeRegression.train(data, 0, 1000)
    sc.stop()
  }
}