Commit 3bf98971 authored by Shivaram Venkataraman

Rename loss -> stochasticLoss and add a note to explain why we have multiple train methods.
parent 84fa20c2
@@ -40,7 +40,8 @@ object GradientDescent {
    * one iteration of SGD. Default value 1.0.
    *
    * @return weights - Column matrix containing weights for every feature.
-   * @return lossHistory - Array containing the loss computed for every iteration.
+   * @return stochasticLossHistory - Array containing the stochastic loss computed for
+   *         every iteration.
    */
   def runMiniBatchSGD(
     data: RDD[(Double, Array[Double])],
@@ -51,7 +52,7 @@ object GradientDescent {
     initialWeights: Array[Double],
     miniBatchFraction: Double=1.0) : (DoubleMatrix, Array[Double]) = {
-    val lossHistory = new ArrayBuffer[Double](numIters)
+    val stochasticLossHistory = new ArrayBuffer[Double](numIters)
     val nexamples: Long = data.count()
     val miniBatchSize = nexamples * miniBatchFraction
@@ -69,12 +70,12 @@ object GradientDescent {
         (grad, loss)
       }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
-      lossHistory.append(lossSum / miniBatchSize + reg_val)
+      stochasticLossHistory.append(lossSum / miniBatchSize + reg_val)
       val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i)
       weights = update._1
       reg_val = update._2
     }
-    (weights, lossHistory.toArray)
+    (weights, stochasticLossHistory.toArray)
   }
 }
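As context for the rename, the sketch below shows the general pattern that the hunks above follow: mini-batch SGD appends the per-iteration stochastic (mini-batch) loss to a history buffer and returns it alongside the final weights. This is a minimal, self-contained local Scala sketch, not the Spark code in this diff; the least-squares loss, fixed random seed, and step-size schedule are illustrative assumptions.

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object MiniBatchSGDSketch {
  def run(
      data: Array[(Double, Array[Double])],
      numIters: Int,
      stepSize: Double,
      miniBatchFraction: Double): (Array[Double], Array[Double]) = {
    val rand = new Random(42)
    val n = data.head._2.length
    val weights = Array.fill(n)(0.0)
    val stochasticLossHistory = new ArrayBuffer[Double](numIters)

    for (i <- 1 to numIters) {
      // Sample a mini-batch of examples for this iteration.
      val batch = data.filter(_ => rand.nextDouble() < miniBatchFraction)
      val miniBatchSize = math.max(batch.length, 1)
      val gradient = Array.fill(n)(0.0)
      var lossSum = 0.0
      for ((label, features) <- batch) {
        val prediction = features.zip(weights).map { case (x, w) => x * w }.sum
        val diff = prediction - label
        lossSum += 0.5 * diff * diff
        for (j <- 0 until n) gradient(j) += diff * features(j)
      }
      // Record the stochastic loss for this iteration, then take a gradient
      // step whose size decays with the iteration count.
      stochasticLossHistory.append(lossSum / miniBatchSize)
      val thisStep = stepSize / math.sqrt(i)
      for (j <- 0 until n) weights(j) -= thisStep * gradient(j) / miniBatchSize
    }
    (weights, stochasticLossHistory.toArray)
  }
}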
@@ -30,7 +30,7 @@ import org.jblas.DoubleMatrix
 class LogisticRegressionModel(
   val weights: DoubleMatrix,
   val intercept: Double,
-  val losses: Array[Double]) extends RegressionModel {
+  val stochasticLosses: Array[Double]) extends RegressionModel {
   override def predict(testData: spark.RDD[Array[Double]]) = {
     testData.map { x =>
@@ -114,7 +114,7 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D
     val initalWeightsWithIntercept = Array(1.0, initialWeights:_*)
-    val (weights, losses) = GradientDescent.runMiniBatchSGD(
+    val (weights, stochasticLosses) = GradientDescent.runMiniBatchSGD(
       data,
       new LogisticGradient(),
       new SimpleUpdater(),
@@ -126,17 +126,19 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D
     val weightsScaled = weights.getRange(1, weights.length)
     val intercept = weights.get(0)
-    val model = new LogisticRegressionModel(weightsScaled, intercept, losses)
+    val model = new LogisticRegressionModel(weightsScaled, intercept, stochasticLosses)
     logInfo("Final model weights " + model.weights)
     logInfo("Final model intercept " + model.intercept)
-    logInfo("Last 10 losses " + model.losses.takeRight(10).mkString(", "))
+    logInfo("Last 10 stochastic losses " + model.stochasticLosses.takeRight(10).mkString(", "))
     model
   }
 }
 /**
  * Top-level methods for calling Logistic Regression.
+ * NOTE(shivaram): We use multiple train methods instead of default arguments to support
+ * Java programs.
  */
 object LogisticRegression {
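With the loss history now stored on the model as stochasticLosses, a caller can make a rough convergence check on its tail, in the spirit of the "Last 10 stochastic losses" log line above. A minimal sketch, with hypothetical names and thresholds that are not part of this commit:

object ConvergenceCheck {
  // Returns true if the last `window` stochastic losses have roughly flattened.
  def roughlyConverged(stochasticLosses: Array[Double], window: Int = 10, tol: Double = 1e-3): Boolean = {
    val tail = stochasticLosses.takeRight(window)
    tail.length >= 2 && math.abs(tail.last - tail.head) < tol
  }
}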
@@ -164,6 +164,8 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
 /**
  * Top-level methods for calling Ridge Regression.
+ * NOTE(shivaram): We use multiple train methods instead of default arguments to support
+ * Java programs.
  */
 object RidgeRegression {
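The NOTE added to both companion objects refers to a Java-interoperability constraint: Scala default arguments are awkward to use from Java, so the objects expose one train overload per optional parameter instead. A hypothetical sketch of that pattern follows; the names, parameters, and defaults are illustrative, not the MLlib API.

object TrainerSketch {
  private val defaultNumIters = 100
  private val defaultStepSize = 1.0

  // Full signature: every parameter supplied explicitly.
  def train(
      data: Seq[(Double, Array[Double])],
      numIters: Int,
      stepSize: Double): Array[Double] = {
    // The actual optimizer call is elided in this sketch.
    Array.fill(data.head._2.length)(0.0)
  }

  // Overloads fill in defaults and forward to the full signature, giving Java
  // callers the convenience that default arguments would give Scala callers.
  def train(data: Seq[(Double, Array[Double])], numIters: Int): Array[Double] =
    train(data, numIters, defaultStepSize)

  def train(data: Seq[(Double, Array[Double])]): Array[Double] =
    train(data, defaultNumIters, defaultStepSize)
}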