Commit 016787de authored by Evan Sparks

Merge pull request #863 from shivaram/etrain-ridge

Adding linear regression and refactoring Ridge regression to use SGD
parents 852d8107 b8c50a06
Showing changed files with 755 additions and 297 deletions
@@ -54,10 +54,17 @@ class LassoWithSGD private (
val gradient = new SquaredGradient()
val updater = new L1Updater()
val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
@transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
// We don't want to penalize the intercept, so set this to false.
setIntercept(false)
var yMean = 0.0
var xColMean: DoubleMatrix = _
var xColSd: DoubleMatrix = _
/**
* Construct a Lasso object with default parameters
@@ -65,7 +72,35 @@ class LassoWithSGD private (
def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LassoModel(weights, intercept)
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
new LassoModel(weightsScaled.data, interceptScaled)
}
override def run(
input: RDD[LabeledPoint],
initialWeights: Array[Double])
: LassoModel =
{
val nfeatures: Int = input.first.features.length
val nexamples: Long = input.count()
// To avoid penalizing the intercept, we center and scale the data.
val stats = MLUtils.computeStats(input, nfeatures, nexamples)
yMean = stats._1
xColMean = stats._2
xColSd = stats._3
val normalizedData = input.map { point =>
val yNormalized = point.label - yMean
val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
LabeledPoint(yNormalized, featuresNormalized.toArray)
}
super.run(normalizedData, initialWeights)
}
}
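A note on the rescaling in createModel above: because run() centers the labels and standardizes the features, the learned weights live in standardized coordinates and have to be mapped back to the original scale. A minimal sketch of that mapping, assuming per-feature means mu and standard deviations sigma (the helper name and signature below are illustrative, not part of the patch):

// Undo feature standardization: if the model was fit on x' = (x - mu) / sigma and y' = y - yMean,
// then on the original scale w_j = wScaled_j / sigma_j and
// intercept = yMean - sum_j (wScaled_j * mu_j / sigma_j).
def unscaleWeights(wScaled: Array[Double], mu: Array[Double], sigma: Array[Double],
    yMean: Double): (Array[Double], Double) = {
  val w = wScaled.zip(sigma).map { case (wj, sj) => wj / sj }
  val b = yMean - wScaled.zip(mu.zip(sigma)).map { case (wj, (mj, sj)) => wj * mj / sj }.sum
  (w, b)
}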
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.mllib.regression
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
/**
* Regression model trained using LinearRegression.
*
* @param weights Weights computed for every feature.
* @param intercept Intercept computed for this model.
*/
class LinearRegressionModel(
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
dataMatrix.dot(weightMatrix) + intercept
}
}
/**
* Train a regression model with no regularization using Stochastic Gradient Descent.
*/
class LinearRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var miniBatchFraction: Double,
var addIntercept: Boolean)
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
val gradient = new SquaredGradient()
val updater = new SimpleUpdater()
val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
.setMiniBatchFraction(miniBatchFraction)
/**
* Construct a LinearRegression object with default parameters
*/
def this() = this(1.0, 100, 1.0, true)
def createModel(weights: Array[Double], intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
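For intuition about what the SquaredGradient/SimpleUpdater pair above optimizes: each SGD step follows the gradient of the plain squared loss, with no regularization term. A rough per-example sketch (an assumption for illustration, not MLlib's actual SquaredGradient source):

// Gradient of 0.5 * (w.x - y)^2 with respect to w is (w.x - y) * x.
def squaredLossGradient(w: Array[Double], x: Array[Double], y: Double): Array[Double] = {
  val err = w.zip(x).map { case (wi, xi) => wi * xi }.sum - y
  x.map(_ * err)
}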
/**
* Top-level methods for calling LinearRegression.
*/
object LinearRegressionWithSGD {
/**
* Train a Linear Regression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
* gradient descent are initialized using the initial weights provided.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeights: Array[Double])
: LinearRegressionModel =
{
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction, true).run(input,
initialWeights)
}
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double)
: LinearRegressionModel =
{
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction, true).run(input)
}
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
* update the gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
* @param numIterations Number of iterations of gradient descent to run.
* @return a LinearRegressionModel which has the weights and offset from training.
*/
def train(
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double)
: LinearRegressionModel =
{
train(input, numIterations, stepSize, 1.0)
}
/**
* Train a LinearRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
* update the gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @return a LinearRegressionModel which has the weights and offset from training.
*/
def train(
input: RDD[LabeledPoint],
numIterations: Int)
: LinearRegressionModel =
{
train(input, numIterations, 1.0, 1.0)
}
def main(args: Array[String]) {
if (args.length != 4) {
println("Usage: LinearRegression <master> <input_dir> <step_size> <niters>")
System.exit(1)
}
val sc = new SparkContext(args(0), "LinearRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
sc.stop()
}
}
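A hedged usage sketch of the static train API documented above (master URL and input path are placeholders, not from the patch):

import spark.SparkContext
import spark.mllib.util.MLUtils
import spark.mllib.regression.LinearRegressionWithSGD

val sc = new SparkContext("local", "LinearRegressionExample")        // placeholder master
val data = MLUtils.loadLabeledData(sc, "/path/to/labeled_data.txt")  // placeholder path
val model = LinearRegressionWithSGD.train(data, 100, 1.0, 1.0)       // numIterations, stepSize, miniBatchFraction
val predictions = model.predict(data.map(_.features))
sc.stop()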
@@ -18,200 +18,198 @@
package spark.mllib.regression
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
import org.jblas.Solve
import scala.annotation.tailrec
import scala.collection.mutable
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
* Regression model trained using RidgeRegression.
*
* @param weights Weights computed for every feature.
* @param intercept Intercept computed for this model.
*/
class RidgeRegressionModel(
val weights: DoubleMatrix,
val intercept: Double,
val lambdaOpt: Double,
val lambdas: Seq[(Double, Double, DoubleMatrix)])
extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
// A small optimization to avoid serializing the entire model.
val localIntercept = this.intercept
val localWeights = this.weights
testData.map { x =>
(new DoubleMatrix(1, x.length, x:_*).mmul(localWeights)).get(0) + localIntercept
}
}
override def predict(testData: Array[Double]): Double = {
(new DoubleMatrix(1, testData.length, testData:_*).mmul(this.weights)).get(0) + this.intercept
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
dataMatrix.dot(weightMatrix) + intercept
}
}
class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
extends Logging {
def this() = this(0.0, 100.0)
/**
* Train a regression model with L2-regularization using Stochastic Gradient Descent.
*/
class RidgeRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
extends GeneralizedLinearAlgorithm[RidgeRegressionModel]
with Serializable {
val gradient = new SquaredGradient()
val updater = new SquaredL2Updater()
@transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
setIntercept(false)
var yMean = 0.0
var xColMean: DoubleMatrix = _
var xColSd: DoubleMatrix = _
/**
* Set the lower bound on binary search for lambda's. Default is 0.
* Construct a RidgeRegression object with default parameters
*/
def setLowLambda(low: Double) = {
this.lambdaLow = low
this
}
def this() = this(1.0, 100, 1.0, 1.0, true)
/**
* Set the upper bound on binary search for lambda's. Default is 100.0.
*/
def setHighLambda(hi: Double) = {
this.lambdaHigh = hi
this
def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
new RidgeRegressionModel(weightsScaled.data, interceptScaled)
}
def train(inputLabeled: RDD[LabeledPoint]): RidgeRegressionModel = {
val input = inputLabeled.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
val nfeatures: Int = input.take(1)(0)._2.length
override def run(
input: RDD[LabeledPoint],
initialWeights: Array[Double])
: RidgeRegressionModel =
{
val nfeatures: Int = input.first.features.length
val nexamples: Long = input.count()
val (yMean, xColMean, xColSd) = MLUtils.computeStats(input, nfeatures, nexamples)
// To avoid penalizing the intercept, we center and scale the data.
val stats = MLUtils.computeStats(input, nfeatures, nexamples)
yMean = stats._1
xColMean = stats._2
xColSd = stats._3
val data = input.map { case(y, features) =>
val yNormalized = y - yMean
val featuresMat = new DoubleMatrix(nfeatures, 1, features:_*)
val normalizedData = input.map { point =>
val yNormalized = point.label - yMean
val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
(yNormalized, featuresNormalized.toArray)
LabeledPoint(yNormalized, featuresNormalized.toArray)
}
// Compute XtX - Size of XtX is nfeatures by nfeatures
val XtX: DoubleMatrix = data.map { case (y, features) =>
val x = new DoubleMatrix(1, features.length, features:_*)
x.transpose().mmul(x)
}.reduce(_.addi(_))
// Compute Xt*y - Size of Xty is nfeatures by 1
val Xty: DoubleMatrix = data.map { case (y, features) =>
new DoubleMatrix(features.length, 1, features:_*).mul(y)
}.reduce(_.addi(_))
// Define a function to compute the leave one out cross validation error
// for a single example
def crossValidate(lambda: Double): (Double, Double, DoubleMatrix) = {
// Compute the MLE ridge regression parameter value
// Ridge Regression parameter = inv(XtX + \lambda*I) * Xty
val XtXlambda = DoubleMatrix.eye(nfeatures).muli(lambda).addi(XtX)
val w = Solve.solveSymmetric(XtXlambda, Xty)
val invXtX = Solve.solveSymmetric(XtXlambda, DoubleMatrix.eye(nfeatures))
// compute the generalized cross validation score
val cverror = data.map {
case (y, features) =>
val x = new DoubleMatrix(features.length, 1, features:_*)
val yhat = w.transpose().mmul(x).get(0)
val H_ii = x.transpose().mmul(invXtX).mmul(x).get(0)
val residual = (y - yhat) / (1.0 - H_ii)
residual * residual
}.reduce(_ + _) / nexamples
(lambda, cverror, w)
}
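For readers unfamiliar with the expression in the removed crossValidate above: for a linear smoother such as ridge regression, the leave-one-out residual can be computed without refitting, which is what the hat-matrix term implements.

// Leave-one-out shortcut used above (standard result for linear smoothers):
//   residual_i = (y_i - yhat_i) / (1 - H_ii),  where  H = X * inv(X'X + lambda*I) * X'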
// Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]
@tailrec
def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
(low, mid + (high-low)/4)
} else {
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
buffer += lowValue += highValue
loop(newLow, newHigh)
} else {
buffer += lowValue += highValue
buffer.result()
}
}
loop(low, high)
}
// Actually compute the best lambda
val lambdas = binSearch(lambdaLow, lambdaHigh).sortBy(_._1)
// Find the best parameter set by taking the lowest cverror.
val (lambdaOpt, cverror, weights) = lambdas.reduce((a, b) => if (a._2 < b._2) a else b)
// Return the model which contains the solution
val weightsScaled = weights.div(xColSd)
val intercept = yMean - (weights.transpose().mmul(xColMean.div(xColSd)).get(0))
val model = new RidgeRegressionModel(weightsScaled, intercept, lambdaOpt, lambdas)
logInfo("RidgeRegression: optimal lambda " + model.lambdaOpt)
logInfo("RidgeRegression: optimal weights " + model.weights)
logInfo("RidgeRegression: optimal intercept " + model.intercept)
logInfo("RidgeRegression: cross-validation error " + cverror)
model
super.run(normalizedData, initialWeights)
}
}
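For reference, pairing SquaredGradient with SquaredL2Updater yields an L2-penalized SGD step; roughly as follows (a sketch of the intended update rule, not MLlib's actual updater code):

// One ridge SGD step: w <- w - step * (gradLoss + regParam * w),
// where gradLoss is the (mini-batch) gradient of the squared loss.
def ridgeStep(w: Array[Double], gradLoss: Array[Double], step: Double,
    regParam: Double): Array[Double] =
  w.zip(gradLoss).map { case (wj, gj) => wj - step * (gj + regParam * wj) }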
/**
* Top-level methods for calling Ridge Regression.
* Top-level methods for calling RidgeRegression.
*/
object RidgeRegression {
// NOTE(shivaram): We use multiple train methods instead of default arguments to support
// Java programs.
object RidgeRegressionWithSGD {
/**
* Train a ridge regression model given an RDD of (response, features) pairs.
* We use the closed form solution to compute the cross-validation score for
* a given lambda. The optimal lambda is computed by performing binary search
* between the provided bounds of lambda.
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
* gradient descent are initialized using the initial weights provided.
*
* @param input RDD of (response, array of features) pairs.
* @param lambdaLow lower bound used in binary search for lambda
* @param lambdaHigh upper bound used in binary search for lambda
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
* @param initialWeights Initial set of weights to be used. Array should be equal in size to
* the number of features in the data.
*/
def train(
input: RDD[LabeledPoint],
lambdaLow: Double,
lambdaHigh: Double)
numIterations: Int,
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
initialWeights: Array[Double])
: RidgeRegressionModel =
{
new RidgeRegression(lambdaLow, lambdaHigh).train(input)
new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(
input, initialWeights)
}
/**
* Train a ridge regression model given an RDD of (response, features) pairs.
* We use the closed form solution to compute the cross-validation score for
* a given lambda. The optimal lambda is computed by performing binary search
* between lambda values of 0 and 100.
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient.
*
* @param input RDD of (response, array of features) pairs.
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
* @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(input: RDD[LabeledPoint]) : RidgeRegressionModel = {
train(input, 0.0, 100.0)
def train(
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
miniBatchFraction: Double)
: RidgeRegressionModel =
{
new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(
input)
}
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
* update the gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
* @param regParam Regularization parameter.
* @param numIterations Number of iterations of gradient descent to run.
* @return a RidgeRegressionModel which has the weights and offset from training.
*/
def train(
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double)
: RidgeRegressionModel =
{
train(input, numIterations, stepSize, regParam, 1.0)
}
/**
* Train a RidgeRegression model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
* update the gradient in each iteration.
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @return a RidgeRegressionModel which has the weights and offset from training.
*/
def train(
input: RDD[LabeledPoint],
numIterations: Int)
: RidgeRegressionModel =
{
train(input, numIterations, 1.0, 1.0, 1.0)
}
def main(args: Array[String]) {
if (args.length != 2) {
println("Usage: RidgeRegression <master> <input_dir>")
if (args.length != 5) {
println("Usage: RidgeRegression <master> <input_dir> <step_size> <regularization_parameter>" +
" <niters>")
System.exit(1)
}
val sc = new SparkContext(args(0), "RidgeRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = RidgeRegression.train(data, 0, 1000)
val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
args(3).toDouble)
sc.stop()
}
}
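To summarize how the three SGD-based regressors touched by this patch differ, here is a paraphrase of the gradient/updater pairs they configure (constants are schematic, not exact MLlib scaling):

// Objectives minimized, schematically (n = number of examples):
//   LinearRegressionWithSGD:  (1/n) * sum 0.5*(w.x - y)^2                               (SimpleUpdater)
//   RidgeRegressionWithSGD:   (1/n) * sum 0.5*(w.x - y)^2 + 0.5 * regParam * ||w||_2^2  (SquaredL2Updater)
//   LassoWithSGD:             (1/n) * sum 0.5*(w.x - y)^2 + regParam * ||w||_1          (L1Updater)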
package spark.mllib.util
import scala.util.Random
import org.jblas.DoubleMatrix
import spark.{RDD, SparkContext}
import spark.mllib.regression.LabeledPoint
/**
* Generate sample data used for Lasso Regression. This class generates uniform random values
* for the features and adds Gaussian noise scaled by 0.1 to the response variable.
*/
object LassoDataGenerator {
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: LassoGenerator " +
"<master> <output_dir> [num_examples] [num_features] [num_partitions]")
System.exit(1)
}
val sparkMaster: String = args(0)
val outputPath: String = args(1)
val nexamples: Int = if (args.length > 2) args(2).toInt else 1000
val nfeatures: Int = if (args.length > 3) args(3).toInt else 2
val parts: Int = if (args.length > 4) args(4).toInt else 2
val sc = new SparkContext(sparkMaster, "LassoGenerator")
val globalRnd = new Random(94720)
val trueWeights = new DoubleMatrix(1, nfeatures+1,
Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx)
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
val y = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1
LabeledPoint(y, x)
}
MLUtils.saveLabeledData(data, outputPath)
sc.stop()
}
}
@@ -17,66 +17,101 @@
package spark.mllib.util
import scala.collection.JavaConversions._
import scala.util.Random
import org.jblas.DoubleMatrix
import spark.{RDD, SparkContext}
import spark.mllib.regression.LabeledPoint
import spark.mllib.regression.LabeledPoint
/**
* Generate sample data used for RidgeRegression. This class generates
* Generate sample data used for linear regression. This class generates
* uniformly random values for every feature and adds Gaussian noise scaled by `eps` to the
* response variable `Y`.
*
*/
object RidgeRegressionDataGenerator {
object LinearDataGenerator {
/**
* Return a Java List of synthetic data randomly generated according to a
* multicollinear model.
* @param intercept Data intercept
* @param weights Weights to be applied.
* @param nPoints Number of points in sample.
* @param seed Random seed
* @return Java List of input.
*/
def generateLinearInputAsList(
intercept: Double,
weights: Array[Double],
nPoints: Int,
seed: Int,
eps: Double): java.util.List[LabeledPoint] = {
seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps))
}
/**
* Generate a sequence of sample data for linear regression.
*
* @param intercept Data intercept
* @param weights Weights to be applied.
* @param nPoints Number of points in sample.
* @param seed Random seed
* @param eps Epsilon scaling factor.
* @return Seq of LabeledPoint containing the sample data.
*/
def generateLinearInput(
intercept: Double,
weights: Array[Double],
nPoints: Int,
seed: Int,
eps: Double = 0.1): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0))
val y = x.map { xi =>
(new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + eps * rnd.nextGaussian()
}
y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
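A quick illustration of the generator just added (the parameter values are arbitrary, chosen only for the example):

// 500 points drawn from y = 2.5 + 1.0*x1 - 0.3*x2 + N(0, 0.1^2), features uniform in [-1, 1].
val sample = LinearDataGenerator.generateLinearInput(
  intercept = 2.5, weights = Array(1.0, -0.3), nPoints = 500, seed = 7, eps = 0.1)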
/**
* Generate an RDD containing sample data for Linear Regression models - including Ridge, Lasso,
* and unregularized variants.
*
* @param sc SparkContext to be used for generating the RDD.
* @param nexamples Number of examples that will be contained in the RDD.
* @param nfeatures Number of features to generate for each example.
* @param eps Epsilon factor by which the Gaussian noise added to the response is scaled.
* @param intercept Intercept added to the response. Default value is 0.0.
* @param nparts Number of partitions in the RDD. Default value is 2.
*
* @return RDD of LabeledPoint containing sample data.
*/
def generateRidgeRDD(
sc: SparkContext,
nexamples: Int,
nfeatures: Int,
eps: Double,
nparts: Int = 2) : RDD[LabeledPoint] = {
def generateLinearRDD(
sc: SparkContext,
nexamples: Int,
nfeatures: Int,
eps: Double,
nparts: Int = 2,
intercept: Double = 0.0) : RDD[LabeledPoint] = {
org.jblas.util.Random.seed(42)
// Random values distributed uniformly in [-0.5, 0.5]
val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
w.put(0, 0, 10)
w.put(1, 0, 10)
val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p =>
org.jblas.util.Random.seed(42 + p)
val seed = 42 + p
val examplesInPartition = nexamples / nparts
val X = DoubleMatrix.rand(examplesInPartition, nfeatures)
val y = X.mmul(w)
val rnd = new Random(42 + p)
val normalValues = Array.fill[Double](examplesInPartition)(rnd.nextGaussian() * eps)
val yObs = new DoubleMatrix(normalValues).addi(y)
Iterator.tabulate(examplesInPartition) { i =>
LabeledPoint(yObs.get(i, 0), X.getRow(i).toArray)
}
generateLinearInput(intercept, w.toArray, examplesInPartition, seed, eps)
}
data
}
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: RidgeRegressionGenerator " +
println("Usage: LinearDataGenerator " +
"<master> <output_dir> [num_examples] [num_features] [num_partitions]")
System.exit(1)
}
@@ -88,8 +123,8 @@ object RidgeRegressionDataGenerator {
val parts: Int = if (args.length > 4) args(4).toInt else 2
val eps = 10
val sc = new SparkContext(sparkMaster, "RidgeRegressionDataGenerator")
val data = generateRidgeRDD(sc, nexamples, nfeatures, eps, parts)
val sc = new SparkContext(sparkMaster, "LinearDataGenerator")
val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts)
MLUtils.saveLabeledData(data, outputPath)
sc.stop()
......
@@ -72,16 +72,16 @@ object MLUtils {
* xColMean - Row vector with mean for every column (or feature) of the input data
* xColSd - Row vector with standard deviation for every column (or feature) of the input data.
*/
def computeStats(data: RDD[(Double, Array[Double])], nfeatures: Int, nexamples: Long):
def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long):
(Double, DoubleMatrix, DoubleMatrix) = {
val yMean: Double = data.map { case (y, features) => y }.reduce(_ + _) / nexamples
val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples
// NOTE: We shuffle X by column here to compute column sum and sum of squares.
val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { case(y, features) =>
val nCols = features.length
val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint =>
val nCols = labeledPoint.features.length
// Traverse over every column and emit (col, value, value^2)
Iterator.tabulate(nCols) { i =>
(i, (features(i), features(i)*features(i)))
(i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i)))
}
}.reduceByKey { case(x1, x2) =>
(x1._1 + x2._1, x1._2 + x2._2)
......
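The remainder of computeStats is collapsed in this hunk; for completeness, the per-column mean and standard deviation follow from the accumulated (sum, sumSq) pairs as sketched below (an assumption about the omitted code, using the population-variance formula):

// mean_i = sum_i / n ;  sd_i = sqrt(sumSq_i / n - mean_i^2)
def meanAndSd(sum: Double, sumSq: Double, n: Long): (Double, Double) = {
  val mean = sum / n
  (mean, math.sqrt(sumSq / n - mean * mean))
}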
@@ -27,6 +27,7 @@ import org.junit.Test;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaLassoSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -61,16 +62,16 @@ public class JavaLassoSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A,
weights, nPoints, 42), 2).cache();
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17);
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoWithSGD svmSGDImpl = new LassoWithSGD();
svmSGDImpl.optimizer().setStepSize(1.0)
LassoWithSGD lassoSGDImpl = new LassoWithSGD();
lassoSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(0.01)
.setNumIterations(20);
LassoModel model = svmSGDImpl.run(testRDD.rdd());
LassoModel model = lassoSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
@@ -82,10 +83,10 @@ public class JavaLassoSuite implements Serializable {
double A = 2.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A,
weights, nPoints, 42), 2).cache();
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17);
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.mllib.regression;
import java.io.Serializable;
import java.util.List;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
}
@After
public void tearDown() {
sc.stop();
sc = null;
System.clearProperty("spark.driver.port");
}
int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
int numAccurate = 0;
for (LabeledPoint point: validationData) {
Double prediction = model.predict(point.features());
// A prediction is off if the prediction is more than 0.5 away from expected value.
if (Math.abs(prediction - point.label()) <= 0.5) {
numAccurate++;
}
}
return numAccurate;
}
@Test
public void runLinearRegressionUsingConstructor() {
int nPoints = 100;
double A = 3.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
@Test
public void runLinearRegressionUsingStaticMethods() {
int nPoints = 100;
double A = 3.0;
double[] weights = {10, 10};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
List<LabeledPoint> validationData =
LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.mllib.regression;
import java.io.Serializable;
import java.util.List;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.jblas.DoubleMatrix;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.mllib.util.LinearDataGenerator;
public class JavaRidgeRegressionSuite implements Serializable {
private transient JavaSparkContext sc;
@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
}
@After
public void tearDown() {
sc.stop();
sc = null;
System.clearProperty("spark.driver.port");
}
double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
double errorSum = 0;
for (LabeledPoint point: validationData) {
Double prediction = model.predict(point.features());
errorSum += (prediction - point.label()) * (prediction - point.label());
}
return errorSum / validationData.size();
}
List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
org.jblas.util.Random.seed(42);
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
// Set first two weights to eps
w.put(0, 0, eps);
w.put(1, 0, eps);
return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
}
@Test
public void runRidgeRegressionUsingConstructor() {
int nexamples = 200;
int nfeatures = 20;
double eps = 10.0;
List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
ridgeSGDImpl.optimizer().setStepSize(1.0)
.setRegParam(0.0)
.setNumIterations(200);
RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
double unRegularizedErr = predictionError(validationData, model);
ridgeSGDImpl.optimizer().setRegParam(0.1);
model = ridgeSGDImpl.run(testRDD.rdd());
double regularizedErr = predictionError(validationData, model);
Assert.assertTrue(regularizedErr < unRegularizedErr);
}
@Test
public void runRidgeRegressionUsingStaticMethods() {
int nexamples = 200;
int nfeatures = 20;
double eps = 10.0;
List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
double unRegularizedErr = predictionError(validationData, model);
model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1);
double regularizedErr = predictionError(validationData, model);
Assert.assertTrue(regularizedErr < unRegularizedErr);
}
}
@@ -24,37 +24,8 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.mllib.util.LinearDataGenerator
import org.jblas.DoubleMatrix
object LassoSuite {
def generateLassoInputAsList(
intercept: Double,
weights: Array[Double],
nPoints: Int,
seed: Int): java.util.List[LabeledPoint] = {
seqAsJavaList(generateLassoInput(intercept, weights, nPoints, seed))
}
// Generate noisy input of the form Y = x.dot(weights) + intercept + noise
def generateLassoInput(
intercept: Double,
weights: Array[Double],
nPoints: Int,
seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
Array.fill[Double](weights.length)(rnd.nextGaussian()))
val y = x.map(xi =>
(new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian()
)
y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
}
class LassoSuite extends FunSuite with BeforeAndAfterAll {
@transient private var sc: SparkContext = _
@@ -85,7 +56,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val B = -1.5
val C = 1.0e-2
val testData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 42)
val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
@@ -101,7 +72,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
val validationData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 17)
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -118,7 +89,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val B = -1.5
val C = 1.0e-2
val testData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 42)
val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
val initialB = -1.0
val initialC = -1.0
@@ -138,7 +109,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
val validationData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 17)
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.mllib.regression
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
import spark.mllib.util.LinearDataGenerator
class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
@transient private var sc: SparkContext = _
override def beforeAll() {
sc = new SparkContext("local", "test")
}
override def afterAll() {
sc.stop()
System.clearProperty("spark.driver.port")
}
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected.label) > 0.5
}.size
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)
}
// Test if we can correctly learn Y = 3 + 10*X1 + 10*X2
test("linear regression") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 42), 2).cache()
val linReg = new LinearRegressionWithSGD()
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
assert(model.intercept >= 2.5 && model.intercept <= 3.5)
assert(model.weights.length === 2)
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
val validationData = LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 17)
val validationRDD = sc.parallelize(validationData, 2).cache()
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
@@ -17,14 +17,16 @@
package spark.mllib.regression
import scala.collection.JavaConversions._
import scala.util.Random
import org.jblas.DoubleMatrix
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
import spark.mllib.util.LinearDataGenerator
class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
@transient private var sc: SparkContext = _
@@ -38,31 +40,51 @@ class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
System.clearProperty("spark.driver.port")
}
// Test if we can correctly learn Y = 3 + X1 + X2 when
// X1 and X2 are collinear.
test("multi-collinear variables") {
val rnd = new Random(43)
val x1 = Array.fill[Double](20)(rnd.nextGaussian())
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
predictions.zip(input).map { case (prediction, expected) =>
(prediction - expected.label) * (prediction - expected.label)
}.reduceLeft(_ + _) / predictions.size
}
test("regularization with skewed weights") {
val nexamples = 200
val nfeatures = 20
val eps = 10
org.jblas.util.Random.seed(42)
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
// Set first two weights to eps
w.put(0, 0, eps)
w.put(1, 0, eps)
// Pick a mean close to mean of x1
val rnd1 = new Random(42) //new NormalDistribution(0.1, 0.01)
val x2 = Array.fill[Double](20)(0.1 + rnd1.nextGaussian() * 0.01)
// Use half of data for training and other half for validation
val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
val testData = data.take(nexamples)
val validationData = data.takeRight(nexamples)
val xMat = (0 until 20).map(i => Array(x1(i), x2(i))).toArray
val testRDD = sc.parallelize(testData, 2).cache()
val validationRDD = sc.parallelize(validationData, 2).cache()
val y = xMat.map(i => 3 + i(0) + i(1))
val testData = (0 until 20).map(i => LabeledPoint(y(i), xMat(i))).toArray
// First run without regularization.
val linearReg = new LinearRegressionWithSGD()
linearReg.optimizer.setNumIterations(200)
.setStepSize(1.0)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val ridgeReg = new RidgeRegression().setLowLambda(0)
.setHighLambda(10)
val linearModel = linearReg.run(testRDD)
val linearErr = predictionError(
linearModel.predict(validationRDD.map(_.features)).collect(), validationData)
val model = ridgeReg.train(testRDD)
val ridgeReg = new RidgeRegressionWithSGD()
ridgeReg.optimizer.setNumIterations(200)
.setRegParam(0.1)
.setStepSize(1.0)
val ridgeModel = ridgeReg.run(testRDD)
val ridgeErr = predictionError(
ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
assert(model.intercept >= 2.9 && model.intercept <= 3.1)
assert(model.weights.length === 2)
assert(model.weights.get(0) >= 0.9 && model.weights.get(0) <= 1.1)
assert(model.weights.get(1) >= 0.9 && model.weights.get(1) <= 1.1)
// Ridge CV-error should be lower than linear regression
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
}