Skip to content
Snippets Groups Projects
Commit 52beb20f authored by DB Tsai's avatar DB Tsai Committed by Xiangrui Meng
Browse files

[SPARK-2477][MLlib] Using appendBias for adding intercept in GeneralizedLinearAlgorithm

Instead of the prependOne method currently used in GeneralizedLinearAlgorithm, we would like to use appendBias in order to 1) keep the indices of the original training set unchanged by adding the intercept as the last element of the vector, and 2) use the same public API to add the intercept consistently.

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #1410 from dbtsai/SPARK-2477_intercept_with_appendBias and squashes the following commits:

011432c [DB Tsai] From Alpine Data Labs
parent dd95abad
No related branches found
No related tags found
No related merge requests found
......@@ -17,13 +17,12 @@
package org.apache.spark.mllib.regression
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.MLUtils._
/**
* :: DeveloperApi ::
......@@ -124,16 +123,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
run(input, initialWeights)
}
/**
* Prepends one to the input vector.
*
* The vector is converted to its Breeze form and a single 1.0 entry is concatenated in front of
* it, so the result is one element longer and every original entry moves up by one index.
*
* @param vector the vector to extend
* @return the extended vector, converted back with Vectors.fromBreeze
* @throws IllegalArgumentException if the Breeze form is neither a dense nor a sparse vector
*/
private def prependOne(vector: Vector): Vector = {
val vector1 = vector.toBreeze match {
// Dense: stack a length-1 vector of ones in front of the original entries.
case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv)
// Sparse: the prefix is a length-1 sparse vector holding 1.0 at index 0.
case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv)
// Any other Breeze vector type is rejected explicitly.
case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
Vectors.fromBreeze(vector1)
}
/**
* Run the algorithm with the configured parameters on an input RDD
* of LabeledPoint entries starting from the initial weights provided.
......@@ -147,23 +136,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
// Append an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features)))
input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}
val initialWeightsWithIntercept = if (addIntercept) {
prependOne(initialWeights)
appendBias(initialWeights)
} else {
initialWeights
}
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0
val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
val weights =
if (addIntercept) {
Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size))
Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
} else {
weightsWithIntercept
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment