Commit 5f1cee6f authored by Nakul Jindal, committed by DB Tsai

[SPARK-11332] [ML] Refactored to use ml.feature.Instance instead of WeightedLeastSquares.Instance

WeightedLeastSquares now uses the common Instance class in ml.feature instead of a private one.

Author: Nakul Jindal <njindal@us.ibm.com>

Closes #9325 from nakul02/SPARK-11332_refactor_WeightedLeastSquares_dot_Instance.
parent 82c1c577
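For reference, the shared class this commit switches to is org.apache.spark.ml.feature.Instance. Its shape is implied by every construction and pattern-match site in the diff below (a sketch inferred from those call sites, not copied from ml/feature):

    case class Instance(label: Double, weight: Double, features: Vector)

The old private WeightedLeastSquares.Instance ordered its fields (weight, features, label), so each call site in this commit reorders its arguments rather than changing behavior.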
mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.optim

 import org.apache.spark.Logging
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
@@ -121,16 +122,6 @@ private[ml] class WeightedLeastSquares(
 private[ml] object WeightedLeastSquares {

-  /**
-   * Case class for weighted observations.
-   * @param w weight, must be positive
-   * @param a features
-   * @param b label
-   */
-  case class Instance(w: Double, a: Vector, b: Double) {
-    require(w >= 0.0, s"Weight cannot be negative: $w.")
-  }
-
   /**
    * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
    */
@@ -168,8 +159,8 @@ private[ml] object WeightedLeastSquares {
     /**
      * Adds an instance.
      */
     def add(instance: Instance): this.type = {
-      val Instance(w, a, b) = instance
-      val ak = a.size
+      val Instance(l, w, f) = instance
+      val ak = f.size
       if (!initialized) {
         init(ak)
       }
@@ -177,11 +168,11 @@ private[ml] object WeightedLeastSquares {
       count += 1L
       wSum += w
       wwSum += w * w
-      bSum += w * b
-      bbSum += w * b * b
-      BLAS.axpy(w, a, aSum)
-      BLAS.axpy(w * b, a, abSum)
-      BLAS.spr(w, a, aaSum)
+      bSum += w * l
+      bbSum += w * l * l
+      BLAS.axpy(w, f, aSum)
+      BLAS.axpy(w * l, f, abSum)
+      BLAS.spr(w, f, aaSum)
       this
     }
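Before and after the rename, Aggregator.add folds each observation into a fixed set of running sums, which is why WeightedLeastSquares needs only one pass over the data: wSum = Σw, bSum = Σw·b, aSum = Σw·a (BLAS.axpy), abSum = Σw·b·a, and aaSum = Σw·a·aᵀ (BLAS.spr, a packed symmetric rank-one update). Together these give AᵀWA and AᵀWb, so the normal equations can be solved from the summary alone. A minimal sketch of the same bookkeeping, hypothetical and using plain arrays rather than mllib vectors:

    // Hypothetical sketch (not Spark code): the running sums Aggregator.add
    // maintains, accumulated over (label, weight, features) triples.
    def summarize(data: Seq[(Double, Double, Array[Double])], k: Int)
        : (Double, Double, Array[Double], Array[Double], Array[Array[Double]]) = {
      var wSum = 0.0                      // sum of w
      var bSum = 0.0                      // sum of w * b
      val aSum  = new Array[Double](k)    // sum of w * a
      val abSum = new Array[Double](k)    // sum of w * b * a
      val aaSum = Array.fill(k, k)(0.0)   // sum of w * a * a^T; Spark keeps only the packed upper triangle via BLAS.spr
      for ((b, w, a) <- data) {
        wSum += w
        bSum += w * b
        for (i <- 0 until k) {
          aSum(i)  += w * a(i)
          abSum(i) += w * b * a(i)
          for (j <- 0 until k) aaSum(i)(j) += w * a(i) * a(j)
        }
      }
      (wSum, bSum, aSum, abSum, aaSum)    // aaSum is A^T W A, abSum is A^T W b
    }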
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -154,10 +154,10 @@ class LinearRegression(override val uid: String)
         "solver is used.'")

       // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
-      val instances: RDD[WeightedLeastSquares.Instance] = dataset.select(
+      val instances: RDD[Instance] = dataset.select(
         col($(labelCol)), w, col($(featuresCol))).map {
         case Row(label: Double, weight: Double, features: Vector) =>
-          WeightedLeastSquares.Instance(weight, features, label)
+          Instance(label, weight, features)
       }
       val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
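One readability note on the new call site: label and weight are both Doubles, so the compiler cannot catch an accidental transposition when porting code from the old (weight, features, label) order. Named arguments make the order explicit; a hypothetical rewrite of the construction above, not part of this commit:

    Instance(label = label, weight = weight, features = features)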
mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.optim

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -38,10 +38,10 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
        w <- c(1, 2, 3, 4)
      */
     instances = sc.parallelize(Seq(
-      Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
-      Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
-      Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
-      Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
+      Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
+      Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
     ), 2)
   }
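After the reordering, the test fixture reads label-first: the labels are 17, 19, 23, 29 and the weights 1, 2, 3, 4, matching w <- c(1, 2, 3, 4) in the quoted R snippet; the feature vectors themselves are unchanged.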