Commit 1db1c656 authored by sethah, committed by DB Tsai

[SPARK-16404][ML] LeastSquaresAggregator serializes unnecessary data

## What changes were proposed in this pull request?
Similar to `LogisticAggregator`, the `LeastSquaresAggregator` used for linear regression ends up serializing the coefficients and the features' standard deviations, which is not necessary and can cause performance issues for high-dimensional data. This patch removes that serialization.

In https://github.com/apache/spark/pull/13729 the approach was to pass these values directly to the `add` method. The approach used here, initially, is to mark these fields as transient instead, which has the benefit of keeping the signature of the `add` method simple and interpretable. The downside is that it requires `transient lazy val`s, which are difficult to reason about if one is not familiar with serialization in Scala/Spark; the sketch below illustrates the pattern.
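
To make the trade-off concrete, here is a minimal, hypothetical sketch of the pattern (the class and method names are invented for illustration, not the patch's actual code): the aggregator keeps only the small `Broadcast` handles as fields, and the large arrays are re-derived on whichever JVM deserializes it via `@transient lazy val`s, so they never travel inside the serialized aggregator.

```scala
import org.apache.spark.broadcast.Broadcast

// Hypothetical aggregator illustrating the @transient lazy val pattern.
class ExampleAggregator(
    bcCoefficients: Broadcast[Array[Double]],
    bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {

  // Small mutable state: the only data worth serializing between stages.
  private var lossSum = 0.0

  // NOT serialized: rebuilt lazily from the (locally cached) broadcast values
  // on each executor, instead of shipping the full arrays with every task.
  @transient private lazy val coefficients = bcCoefficients.value
  @transient private lazy val featuresStd = bcFeaturesStd.value

  def loss: Double = lossSum

  def add(label: Double, features: Array[Double]): this.type = {
    var margin = 0.0
    var i = 0
    while (i < features.length) {
      if (featuresStd(i) != 0.0) {
        margin += coefficients(i) * features(i) / featuresStd(i)
      }
      i += 1
    }
    val diff = margin - label
    lossSum += diff * diff / 2.0
    this
  }

  def merge(other: ExampleAggregator): this.type = {
    lossSum += other.lossSum
    this
  }
}
```

One subtlety the patch itself works around: destructuring a tuple in a `lazy val` (e.g. `lazy val (a, b) = ...`) generates a synthetic field that the `@transient` annotation does not reach, which is why the patch stores `effectiveCoefAndOffset` as a single value and unpacks `_1`/`_2` in separate `@transient lazy val`s.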

## How was this patch tested?

**MLlib**
![image](https://cloud.githubusercontent.com/assets/7275795/16703660/436f79fa-4524-11e6-9022-ef00058ec718.png)

**ML without patch**
![image](https://cloud.githubusercontent.com/assets/7275795/16703831/c4d50b9e-4525-11e6-80cb-9b58c850cd41.png)

**ML with patch**
![image](https://cloud.githubusercontent.com/assets/7275795/16703675/63e0cf40-4524-11e6-9120-1f512a70e083.png)

Author: sethah <seth.hendrickson16@gmail.com>

Closes #14109 from sethah/LIR_serialize.
parent e076fb05
```diff
@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -82,6 +83,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set the regularization parameter.
    * Default is 0.0.
+   *
    * @group setParam
    */
   @Since("1.3.0")
@@ -91,6 +93,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set if we should fit the intercept
    * Default is true.
+   *
    * @group setParam
    */
   @Since("1.5.0")
@@ -104,6 +107,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * the models should be always converged to the same solution when no regularization
    * is applied. In R's GLMNET package, the default behavior is true as well.
    * Default is true.
+   *
    * @group setParam
    */
   @Since("1.5.0")
@@ -115,6 +119,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
    * Default is 0.0 which is an L2 penalty.
+   *
    * @group setParam
    */
   @Since("1.4.0")
@@ -124,6 +129,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set the maximum number of iterations.
    * Default is 100.
+   *
    * @group setParam
    */
   @Since("1.3.0")
@@ -134,6 +140,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * Set the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
    * Default is 1E-6.
+   *
    * @group setParam
    */
   @Since("1.4.0")
@@ -144,6 +151,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * Whether to over-/under-sample training instances according to the given weights in weightCol.
    * If not set or empty, all instances are treated equally (weight 1.0).
    * Default is not set, so all instances have weight one.
+   *
    * @group setParam
    */
   @Since("1.6.0")
@@ -157,6 +165,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * solution to the linear regression problem.
    * The default value is "auto" which means that the solver algorithm is
    * selected automatically.
+   *
    * @group setParam
    */
   @Since("1.6.0")
@@ -270,6 +279,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
     val featuresMean = featuresSummarizer.mean.toArray
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val bcFeaturesMean = instances.context.broadcast(featuresMean)
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
 
     if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
       featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
@@ -285,7 +296,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
     val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
-      $(standardization), featuresStd, featuresMean, effectiveL2RegParam)
+      $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -330,6 +341,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       throw new SparkException(msg)
     }
 
+    bcFeaturesMean.destroy(blocking = false)
+    bcFeaturesStd.destroy(blocking = false)
+
     /*
        The coefficients are trained in the scaled space; we're converting them back to
        the original space.
@@ -419,6 +433,7 @@ class LinearRegressionModel private[ml] (
   /**
    * Evaluates the model on a test dataset.
+   *
    * @param dataset Test dataset to evaluate model on.
    */
   @Since("2.0.0")
@@ -544,6 +559,7 @@ class LinearRegressionTrainingSummary private[regression] (
    * Number of training iterations until termination
    *
    * This value is only available when using the "l-bfgs" solver.
+   *
    * @see [[LinearRegression.solver]]
    */
   @Since("1.5.0")
@@ -862,27 +878,31 @@ class LinearRegressionSummary private[regression] (
  * $$
  * </blockquote></p>
  *
- * @param coefficients The coefficients corresponding to the features.
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
  * @param labelStd The standard deviation value of the label.
  * @param labelMean The mean value of the label.
  * @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
- * @param featuresMean The mean values of the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
+ * @param bcFeaturesMean The broadcast mean values of the features.
  */
 private class LeastSquaresAggregator(
-    coefficients: Vector,
+    bcCoefficients: Broadcast[Vector],
     labelStd: Double,
     labelMean: Double,
     fitIntercept: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double]) extends Serializable {
+    bcFeaturesStd: Broadcast[Array[Double]],
+    bcFeaturesMean: Broadcast[Array[Double]]) extends Serializable {
 
   private var totalCnt: Long = 0L
   private var weightSum: Double = 0.0
   private var lossSum = 0.0
 
-  private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = {
-    val coefficientsArray = coefficients.toArray.clone()
+  private val dim = bcCoefficients.value.size
+  // make transient so we do not serialize between aggregation stages
+  @transient private lazy val featuresStd = bcFeaturesStd.value
+  @transient private lazy val effectiveCoefAndOffset = {
+    val coefficientsArray = bcCoefficients.value.toArray.clone()
+    val featuresMean = bcFeaturesMean.value
     var sum = 0.0
     var i = 0
     val len = coefficientsArray.length
@@ -896,10 +916,11 @@ private class LeastSquaresAggregator(
       i += 1
     }
     val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0
-    (coefficientsArray, offset, coefficientsArray.length)
+    (Vectors.dense(coefficientsArray), offset)
   }
-
-  private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray)
+  // do not use tuple assignment above because it will circumvent the @transient tag
+  @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1
+  @transient private lazy val offset = effectiveCoefAndOffset._2
 
   private val gradientSumArray = Array.ofDim[Double](dim)
@@ -922,9 +943,10 @@ private class LeastSquaresAggregator(
     if (diff != 0) {
       val localGradientSumArray = gradientSumArray
+      val localFeaturesStd = featuresStd
       features.foreachActive { (index, value) =>
-        if (featuresStd(index) != 0.0 && value != 0.0) {
-          localGradientSumArray(index) += weight * diff * value / featuresStd(index)
+        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+          localGradientSumArray(index) += weight * diff * value / localFeaturesStd(index)
         }
       }
       lossSum += weight * diff * diff / 2.0
@@ -992,23 +1014,26 @@ private class LeastSquaresCostFun(
     labelMean: Double,
     fitIntercept: Boolean,
     standardization: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double],
+    bcFeaturesStd: Broadcast[Array[Double]],
+    bcFeaturesMean: Broadcast[Array[Double]],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
 
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val coeffs = Vectors.fromBreeze(coefficients)
+    val bcCoeffs = instances.context.broadcast(coeffs)
+    val localFeaturesStd = bcFeaturesStd.value
 
     val leastSquaresAggregator = {
       val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance)
       val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2)
 
       instances.treeAggregate(
-        new LeastSquaresAggregator(coeffs, labelStd, labelMean, fitIntercept, featuresStd,
-          featuresMean))(seqOp, combOp)
+        new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd,
+          bcFeaturesMean))(seqOp, combOp)
     }
 
     val totalGradientArray = leastSquaresAggregator.gradient.toArray
+    bcCoeffs.destroy(blocking = false)
 
     val regVal = if (effectiveL2regParam == 0.0) {
       0.0
@@ -1022,13 +1047,13 @@ private class LeastSquaresCostFun(
           totalGradientArray(index) += effectiveL2regParam * value
           value * value
         } else {
-          if (featuresStd(index) != 0.0) {
+          if (localFeaturesStd(index) != 0.0) {
             // If `standardization` is false, we still standardize the data
             // to improve the rate of convergence; as a result, we have to
             // perform this reverse standardization by penalizing each component
             // differently to get effectively the same objective function when
             // the training dataset is not standardized.
-            val temp = value / (featuresStd(index) * featuresStd(index))
+            val temp = value / (localFeaturesStd(index) * localFeaturesStd(index))
             totalGradientArray(index) += effectiveL2regParam * temp
             value * temp
           } else {
```
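
For context, a hedged sketch of the per-iteration broadcast lifecycle that `LeastSquaresCostFun.calculate` now follows, reusing the hypothetical `ExampleAggregator` from above (the helper name and RDD layout are assumptions, not the actual code): broadcast the current coefficients, run `treeAggregate`, then destroy the broadcast so repeated L-BFGS iterations do not accumulate stale broadcast blocks on the executors.

```scala
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// Hypothetical driver-side helper: one broadcast per optimizer iteration,
// released as soon as the aggregate has been collected back to the driver.
def computeLoss(
    instances: RDD[(Double, Array[Double])], // (label, features)
    coefficients: Array[Double],
    bcFeaturesStd: Broadcast[Array[Double]]): Double = {
  val bcCoeffs = instances.context.broadcast(coefficients)
  val agg = instances.treeAggregate(new ExampleAggregator(bcCoeffs, bcFeaturesStd))(
    seqOp = (c, point) => c.add(point._1, point._2),
    combOp = (c1, c2) => c1.merge(c2))
  // The patch calls the internal destroy(blocking = false); outside Spark's own
  // packages the public no-argument destroy() is the equivalent cleanup.
  bcCoeffs.destroy()
  agg.loss
}
```

Note that the zero value passed to `treeAggregate` is itself serialized into every task closure, which is why keeping the aggregator's serialized footprint small matters for high-dimensional models.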