From 1665b5f724b486068a62c9c72dfd7ed76807c1b3 Mon Sep 17 00:00:00 2001
From: sethah <seth.hendrickson16@gmail.com>
Date: Mon, 5 Jun 2017 10:32:17 +0100
Subject: [PATCH] [SPARK-19762][ML] Hierarchy for consolidating ML aggregator/loss code

## What changes were proposed in this pull request?

JIRA: [SPARK-19762](https://issues.apache.org/jira/browse/SPARK-19762)

The larger changes in this patch are:

* Adds a `DifferentiableLossAggregator` trait which is intended to be used as a common parent trait for all Spark ML aggregator classes. It factors out the common methods `merge`, `gradient`, `loss`, and `weight` from the aggregator subclasses.
* Adds an `RDDLossFunction` which is intended to be the only implementation of Breeze's `DiffFunction` necessary in Spark ML, and can be used by all other algorithms. It takes the aggregator type as a type parameter, and maps the aggregator over an RDD. It additionally takes an optional regularization loss function for applying the differentiable part of regularization.
* Factors out the regularization from the data part of the cost function, and treats regularization as a separate, independent cost function which can be evaluated and added to the data cost function.
* Changes `LinearRegression` to use this new hierarchy as a proof of concept.
* Adds the following new namespaces: `o.a.s.ml.optim.loss` and `o.a.s.ml.optim.aggregator`.

Also note that none of these are public-facing changes. All of these classes are internal to Spark ML and remain that way.

**NOTE: The large majority of the "lines added" and "lines deleted" are simply code moving around or unit tests.**

BTW, I also converted LinearSVC to this framework as a way to prove that this new hierarchy is flexible enough for the other algorithms, but I backed those changes out because the PR is large enough as is.

## How was this patch tested?

Test suites are added for the new components, and some test suites are also added to provide coverage where there wasn't any before.

* DifferentiableLossAggregatorSuite
* LeastSquaresAggregatorSuite
* RDDLossFunctionSuite
* DifferentiableRegularizationSuite

Below are some performance testing numbers, run on a 6-node virtual cluster with 44 cores and ~110G of RAM; the dataset size is about 37G. These are not "large-scale" tests, but we really want to just make sure the iteration times don't increase with this patch. Notably, we are doing the regularization a bit differently than before, but that should cost very little. I think there's very little risk otherwise, and these numbers don't show a difference. Of course I'm happy to add more tests as we think it's necessary, but I think the patch is ready for review now.

**Note:** timings are best of 3 runs.

|    | numFeatures | numPoints | maxIter | regParam | elasticNetParam | SPARK-19762 (sec) | master (sec) |
|----|-------------|-----------|---------|----------|-----------------|-------------------|--------------|
| 0  | 5000        | 1e+06     | 30      | 0        | 0               | 129.594           | 131.153      |
| 1  | 5000        | 1e+06     | 30      | 0.1      | 0               | 135.54            | 136.327      |
| 2  | 5000        | 1e+06     | 30      | 0.01     | 0.5             | 135.148           | 129.771      |
| 3  | 50000       | 100000    | 30      | 0        | 0               | 145.764           | 144.096      |

## Follow ups

If this design is accepted, we will convert the other ML algorithms that use this aggregator pattern to this new hierarchy in follow-up PRs.

Author: sethah <seth.hendrickson16@gmail.com>
Author: sethah <shendrickson@cloudera.com>

Closes #17094 from sethah/ml_aggregators.
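As a sketch of how the pieces above compose (illustrative only; `yStd`, `yMean`, `fitIntercept`, `bcFeaturesStd`, `bcFeaturesMean`, `effectiveL2RegParam`, `numFeatures`, `featuresStd`, `standardization`, `instances`, `maxIter`, `tol`, and `initialCoefficients` are assumed to be in scope, mirroring the `LinearRegression` change in the diff below):

```scala
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS}
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}

// The curried constructor yields a Broadcast[Vector] => LeastSquaresAggregator function,
// so the cost function can build a fresh aggregator for each coefficient vector.
val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, fitIntercept,
  bcFeaturesStd, bcFeaturesMean)(_)

// Regularization is now a separate, optional, differentiable cost component.
val regularization = if (effectiveL2RegParam != 0.0) {
  val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
  Some(new L2Regularization(effectiveL2RegParam, shouldApply,
    if (standardization) None else Some(featuresStd)))
} else {
  None
}

// The single generic DiffFunction maps the aggregator over the RDD of instances.
val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization)
val optimizer = new BreezeLBFGS[BDV[Double]](maxIter, 10, tol)
val states = optimizer.iterations(
  new CachedDiffFunction(costFun), initialCoefficients.asBreeze.toDenseVector)
```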
--- .../DifferentiableLossAggregator.scala | 88 +++++ .../aggregator/LeastSquaresAggregator.scala | 224 ++++++++++++ .../loss/DifferentiableRegularization.scala | 71 ++++ .../spark/ml/optim/loss/RDDLossFunction.scala | 72 ++++ .../ml/regression/LinearRegression.scala | 327 +----------------- .../DifferentiableLossAggregatorSuite.scala | 160 +++++++++ .../LeastSquaresAggregatorSuite.scala | 157 +++++++++ .../DifferentiableRegularizationSuite.scala | 61 ++++ .../ml/optim/loss/RDDLossFunctionSuite.scala | 83 +++++ 9 files changed, 930 insertions(+), 313 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala new file mode 100644 index 0000000000..403c28ff73 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} + +/** + * A parent trait for aggregators used in fitting MLlib models. This parent trait implements + * some of the common code shared between concrete instances of aggregators. Subclasses of this + * aggregator need only implement the `add` method. + * + * @tparam Datum The type of the instances added to the aggregator to update the loss and gradient. + * @tparam Agg Specialization of [[DifferentiableLossAggregator]]. Classes that subclass this + * type need to use this parameter to specify the concrete type of the aggregator. 
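+ *
+ * For example, a minimal concrete aggregator for plain, unstandardized squared-error loss
+ * (a hypothetical sketch, not part of this patch; compare the `TestAggregator` in the new
+ * test suites) only needs to provide `dim` and `add`:
+ *
+ * {{{
+ *   class SquaredErrorAggregator(coefficients: Vector)
+ *     extends DifferentiableLossAggregator[Instance, SquaredErrorAggregator] {
+ *
+ *     protected override val dim: Int = coefficients.size
+ *
+ *     def add(instance: Instance): SquaredErrorAggregator = {
+ *       // error is (prediction - label); d/dw_j of (error^2 / 2) is error * x_j
+ *       val error = BLAS.dot(coefficients, instance.features) - instance.label
+ *       weightSum += instance.weight
+ *       lossSum += instance.weight * error * error / 2.0
+ *       instance.features.foreachActive { (j, value) =>
+ *         gradientSumArray(j) += instance.weight * error * value
+ *       }
+ *       this
+ *     }
+ *   }
+ * }}}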
+ */ +private[ml] trait DifferentiableLossAggregator[ + Datum, + Agg <: DifferentiableLossAggregator[Datum, Agg]] extends Serializable { + + self: Agg => // enforce classes that extend this to be the same type as `Agg` + + protected var weightSum: Double = 0.0 + protected var lossSum: Double = 0.0 + + /** The dimension of the gradient array. */ + protected val dim: Int + + /** Array of gradient values that are mutated when new instances are added to the aggregator. */ + protected lazy val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** Add a single data point to this aggregator. */ + def add(instance: Datum): Agg + + /** Merge two aggregators. The `this` object will be modified in place and returned. */ + def merge(other: Agg): Agg = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"${getClass.getSimpleName}. Expecting $dim but got ${other.dim}.") + + if (other.weightSum != 0) { + weightSum += other.weightSum + lossSum += other.lossSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + /** The current weighted averaged gradient. */ + def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but was $weightSum.") + val result = Vectors.dense(gradientSumArray.clone()) + BLAS.scal(1.0 / weightSum, result) + result + } + + /** Weighted count of instances in this aggregator. */ + def weight: Double = weightSum + + /** The current loss value of this aggregator. */ + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but was $weightSum.") + lossSum / weightSum + } + +} + diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala new file mode 100644 index 0000000000..1994b0e40e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregator.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in an online fashion. 
+ *
+ * Two LeastSquaresAggregators can be merged together to produce a summary of the loss and
+ * gradient of the corresponding joint dataset.
+ *
+ * To improve the convergence rate during the optimization process, and also to prevent
+ * features with very large variances from exerting an overly large influence during model
+ * training, packages like R's GLMNET perform scaling to unit variance and remove the mean
+ * to reduce the condition number. They then train the model in the scaled space, but return
+ * the coefficients in the original scale.
+ * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
+ *
+ * However, we don't want to apply the `StandardScaler` to the training dataset and then
+ * cache the standardized dataset, since doing so would create a lot of overhead. As a
+ * result, we perform the scaling implicitly when we compute the objective function. The
+ * following is the mathematical derivation.
+ *
+ * Note that we don't deal with the intercept by adding a bias term here, because the
+ * intercept can be computed in closed form after the coefficients have converged.
+ * See this discussion for detail.
+ * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ *
+ * When training with the intercept enabled, the objective function in the scaled space
+ * is given by
+ *
+ * <blockquote>
+ *    $$
+ *    L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
+ *    $$
+ * </blockquote>
+ *
+ * where $\bar{x_i}$ is the mean of $x_i$, $\hat{x_i}$ is the standard deviation of $x_i$,
+ * $\bar{y}$ is the mean of the label, and $\hat{y}$ is the standard deviation of the label.
+ *
+ * If we are fitting with the intercept disabled (that is, forced through 0.0), we can use
+ * the same equation, except that $\bar{y}$ and $\bar{x_i}$ are set to 0 instead of the
+ * respective means.
+ *
+ * This can be rewritten as
+ *
+ * <blockquote>
+ *    $$
+ *    \begin{align}
+ *    L &= 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
+ *          + \bar{y} / \hat{y}||^2 \\
+ *      &= 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
+ *    \end{align}
+ *    $$
+ * </blockquote>
+ *
+ * where $w_i^\prime$ are the effective coefficients, defined by $w_i/\hat{x_i}$, offset is
+ *
+ * <blockquote>
+ *    $$
+ *    - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y},
+ *    $$
+ * </blockquote>
+ *
+ * and diff is
+ *
+ * <blockquote>
+ *    $$
+ *    \sum_i w_i^\prime x_i - y / \hat{y} + offset.
+ *    $$
+ * </blockquote>
+ *
+ * Note that the effective coefficients and offset don't depend on the training dataset,
+ * so they can be precomputed.
+ *
+ * Now, the first derivative of the objective function in the scaled space is
+ *
+ * <blockquote>
+ *    $$
+ *    \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i}
+ *    $$
+ * </blockquote>
+ *
+ * However, $(x_i - \bar{x_i})$ will densify the computation, so it's not an ideal formula
+ * when the training dataset is in sparse format.
+ *
+ * This can be addressed by adding the dense $\bar{x_i} / \hat{x_i}$ terms at
+ * the end by keeping the sum of diff.
+ * The first derivative of the total objective function over all the samples is
+ *
+ * <blockquote>
+ *    $$
+ *    \begin{align}
+ *    \frac{\partial L}{\partial w_i} &=
+ *        1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} \\
+ *      &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) \\
+ *      &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i)
+ *    \end{align}
+ *    $$
+ * </blockquote>
+ *
+ * where $correction_i = - diffSum \bar{x_i} / \hat{x_i}$.
+ *
+ * Simple math shows that diffSum is actually zero, so we don't even need to add the
+ * correction terms in the end. From the definition of diff,
+ *
+ * <blockquote>
+ *    $$
+ *    \begin{align}
+ *    diffSum &= \sum_j (\sum_i w_i(x_{ij} - \bar{x_i})
+ *                / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) \\
+ *      &= N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) \\
+ *      &= 0
+ *    \end{align}
+ *    $$
+ * </blockquote>
+ *
+ * As a result, the first derivative of the total objective function depends only on
+ * the training dataset, can be easily computed in a distributed fashion, and is
+ * sparse-format friendly.
+ *
+ * <blockquote>
+ *    $$
+ *    \frac{\partial L}{\partial w_i} = 1/N \sum_j diff_j x_{ij} / \hat{x_i}
+ *    $$
+ * </blockquote>
+ *
+ * @note The constructor is curried, since the cost function will repeatedly create new
+ *       versions of this class for different coefficient vectors.
+ *
+ * @param labelStd The standard deviation value of the label.
+ * @param labelMean The mean value of the label.
+ * @param fitIntercept Whether to fit an intercept term.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
+ * @param bcFeaturesMean The broadcast mean values of the features.
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
+ */
+private[ml] class LeastSquaresAggregator(
+    labelStd: Double,
+    labelMean: Double,
+    fitIntercept: Boolean,
+    bcFeaturesStd: Broadcast[Array[Double]],
+    bcFeaturesMean: Broadcast[Array[Double]])(bcCoefficients: Broadcast[Vector])
+  extends DifferentiableLossAggregator[Instance, LeastSquaresAggregator] {
+  require(labelStd > 0.0, s"${this.getClass.getName} requires the label standard " +
+    s"deviation to be positive.")
+
+  private val numFeatures = bcFeaturesStd.value.length
+  protected override val dim: Int = numFeatures
+  // make transient so we do not serialize between aggregation stages
+  @transient private lazy val featuresStd = bcFeaturesStd.value
+  @transient private lazy val effectiveCoefAndOffset = {
+    val coefficientsArray = bcCoefficients.value.toArray.clone()
+    val featuresMean = bcFeaturesMean.value
+    var sum = 0.0
+    var i = 0
+    val len = coefficientsArray.length
+    while (i < len) {
+      if (featuresStd(i) != 0.0) {
+        coefficientsArray(i) /= featuresStd(i)
+        sum += coefficientsArray(i) * featuresMean(i)
+      } else {
+        coefficientsArray(i) = 0.0
+      }
+      i += 1
+    }
+    val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0
+    (Vectors.dense(coefficientsArray), offset)
+  }
+  // do not use tuple assignment above because it will circumvent the @transient tag
+  @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1
+  @transient private lazy val offset = effectiveCoefAndOffset._2
+
+  /**
+   * Add a new training instance to this LeastSquaresAggregator, and update the loss and
+   * gradient of the objective function.
+   *
+   * @param instance The data point instance to be added.
+   * @return This LeastSquaresAggregator object.
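+   * @note Informally, following the class-level derivation above (a descriptive summary of
+   *       the code below, not additional behavior): for an instance with label $y$, weight
+   *       $w$, and features $x$, this computes $diff = x \cdot w^\prime - y / \hat{y} + offset$,
+   *       then adds $w \cdot diff^2 / 2$ to the loss sum, adds $w$ to the weight sum, and,
+   *       for each active feature value $x_i$ with $\hat{x_i} \neq 0$, adds
+   *       $w \cdot diff \cdot x_i / \hat{x_i}$ to gradient component $i$.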
+   */
+  def add(instance: Instance): LeastSquaresAggregator = {
+    instance match { case Instance(label, weight, features) =>
+      require(numFeatures == features.size, s"Dimensions mismatch when adding new sample." +
+        s" Expecting $numFeatures but got ${features.size}.")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
+
+      if (weight == 0.0) return this
+
+      val diff = BLAS.dot(features, effectiveCoefficientsVector) - label / labelStd + offset
+
+      if (diff != 0) {
+        val localGradientSumArray = gradientSumArray
+        val localFeaturesStd = featuresStd
+        features.foreachActive { (index, value) =>
+          val fStd = localFeaturesStd(index)
+          if (fStd != 0.0 && value != 0.0) {
+            localGradientSumArray(index) += weight * diff * value / fStd
+          }
+        }
+        lossSum += weight * diff * diff / 2.0
+      }
+      weightSum += weight
+      this
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
new file mode 100644
index 0000000000..118c0ebfa5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.loss
+
+import breeze.optimize.DiffFunction
+
+/**
+ * A Breeze diff function which represents a cost function for differentiable regularization
+ * of parameters, e.g. L2 regularization: (1/2) * regParam * (beta dot beta).
+ *
+ * @tparam T The type of the coefficients being regularized.
+ */
+private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] {
+
+  /** Magnitude of the regularization penalty. */
+  def regParam: Double
+
+}
+
+/**
+ * A Breeze diff function for computing the L2 regularized loss and gradient of an array of
+ * coefficients.
+ *
+ * @param regParam The magnitude of the regularization.
+ * @param shouldApply A function (Int => Boolean) indicating whether a given index should have
+ *                    regularization applied to it.
+ * @param featuresStd Option containing the standard deviations of the features. When present,
+ *                    the penalty for each coefficient is scaled by the inverse squared
+ *                    standard deviation of the corresponding feature, to account for the
+ *                    implicit standardization of the features.
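+ *
+ * For example (a small illustration of the contract, with hypothetical values): with
+ * `regParam = 0.1`, regularization applied to every index, and no feature scaling, the loss
+ * for coefficients `Array(1.0, 3.0)` is 0.5 * 0.1 * (1 + 9) = 0.5 and the gradient is
+ * `Array(0.1, 0.3)`:
+ *
+ * {{{
+ *   val l2 = new L2Regularization(0.1, (_: Int) => true, None)
+ *   val (loss, gradient) = l2.calculate(Array(1.0, 3.0))
+ * }}}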
+ */
+private[ml] class L2Regularization(
+    val regParam: Double,
+    shouldApply: Int => Boolean,
+    featuresStd: Option[Array[Double]]) extends DifferentiableRegularization[Array[Double]] {
+
+  override def calculate(coefficients: Array[Double]): (Double, Array[Double]) = {
+    var sum = 0.0
+    val gradient = new Array[Double](coefficients.length)
+    coefficients.indices.filter(shouldApply).foreach { j =>
+      val coef = coefficients(j)
+      featuresStd match {
+        case Some(stds) =>
+          val std = stds(j)
+          if (std != 0.0) {
+            val temp = coef / (std * std)
+            sum += coef * temp
+            gradient(j) = regParam * temp
+          } else {
+            0.0
+          }
+        case None =>
+          sum += coef * coef
+          gradient(j) = coef * regParam
+      }
+    }
+    (0.5 * sum * regParam, gradient)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
new file mode 100644
index 0000000000..3b1618eb0b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.loss
+
+import scala.reflect.ClassTag
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.DiffFunction
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
+import org.apache.spark.rdd.RDD
+
+/**
+ * This class computes the gradient and loss of a differentiable loss function by mapping a
+ * [[DifferentiableLossAggregator]] over an [[RDD]] of [[Instance]]s. The loss function is the
+ * sum of the losses computed on the individual instances across all points in the RDD.
+ * Therefore, the actual analytical form of the loss function is specified by the aggregator,
+ * which computes each point's contribution to the overall loss.
+ *
+ * A differentiable regularization component can also be added by providing a
+ * [[DifferentiableRegularization]] loss function.
+ *
+ * @param instances The RDD of data points over which the loss and gradient are computed.
+ * @param getAggregator A function which gets a new loss aggregator in every tree aggregate step.
+ * @param regularization An option representing the regularization loss function to apply to the
+ *                       coefficients.
+ * @param aggregationDepth The aggregation depth of the tree aggregation step.
+ * @tparam T The type of the data points.
+ * @tparam Agg Specialization of [[DifferentiableLossAggregator]], representing the concrete type
+ *             of the aggregator.
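+ *
+ * A sketch of the intended usage with a Breeze optimizer (illustrative; `instances`,
+ * `getAggregator`, and `regularization` are assumed to be constructed as in
+ * `LinearRegression` below):
+ *
+ * {{{
+ *   val costFun = new RDDLossFunction(instances, getAggregator, regularization)
+ *   // Each call broadcasts the coefficients, treeAggregates a fresh aggregator over the
+ *   // RDD, and adds the differentiable regularization loss and gradient, if present.
+ *   val (loss, gradient) = costFun.calculate(coefficients.asBreeze.toDenseVector)
+ * }}}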
+ */ +private[ml] class RDDLossFunction[ + T: ClassTag, + Agg <: DifferentiableLossAggregator[T, Agg]: ClassTag]( + instances: RDD[T], + getAggregator: (Broadcast[Vector] => Agg), + regularization: Option[DifferentiableRegularization[Array[Double]]], + aggregationDepth: Int = 2) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val bcCoefficients = instances.context.broadcast(Vectors.fromBreeze(coefficients)) + val thisAgg = getAggregator(bcCoefficients) + val seqOp = (agg: Agg, x: T) => agg.add(x) + val combOp = (agg1: Agg, agg2: Agg) => agg1.merge(agg2) + val newAgg = instances.treeAggregate(thisAgg)(seqOp, combOp, aggregationDepth) + val gradient = newAgg.gradient + val regLoss = regularization.map { regFun => + val (regLoss, regGradient) = regFun.calculate(coefficients.data) + BLAS.axpy(1.0, Vectors.dense(regGradient), gradient) + regLoss + }.getOrElse(0.0) + bcCoefficients.destroy(blocking = false) + (newAgg.loss + regLoss, gradient.asBreeze.toDenseVector) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index eaad549852..db5ac4f14b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,19 +20,20 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator +import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ @@ -319,8 +320,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam - val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept), - $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth)) + val getAggregatorFunc = new LeastSquaresAggregator(yStd, yMean, $(fitIntercept), + bcFeaturesStd, bcFeaturesMean)(_) + val regularization = if (effectiveL2RegParam != 0.0) { + val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures + Some(new L2Regularization(effectiveL2RegParam, shouldApply, + if ($(standardization)) None else Some(featuresStd))) + } else { + None + } + val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization, + $(aggregationDepth)) val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) @@ -793,312 +803,3 @@ class 
LinearRegressionSummary private[regression] ( } -/** - * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, - * as used in linear regression for samples in sparse or dense vector in an online fashion. - * - * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of - * the corresponding joint dataset. - * - * For improving the convergence rate during the optimization process, and also preventing against - * features with very large variances exerting an overly large influence during model training, - * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce - * the condition number, and then trains the model in scaled space but returns the coefficients in - * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf - * - * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache - * the standardized dataset since it will create a lot of overhead. As a result, we perform the - * scaling implicitly when we compute the objective function. The following is the mathematical - * derivation. - * - * Note that we don't deal with intercept by adding bias here, because the intercept - * can be computed using closed form after the coefficients are converged. - * See this discussion for detail. - * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet - * - * When training with intercept enabled, - * The objective function in the scaled space is given by - * - * <blockquote> - * $$ - * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, - * $$ - * </blockquote> - * - * where $\bar{x_i}$ is the mean of $x_i$, $\hat{x_i}$ is the standard deviation of $x_i$, - * $\bar{y}$ is the mean of label, and $\hat{y}$ is the standard deviation of label. - * - * If we fitting the intercept disabled (that is forced through 0.0), - * we can use the same equation except we set $\bar{y}$ and $\bar{x_i}$ to 0 instead - * of the respective means. - * - * This can be rewritten as - * - * <blockquote> - * $$ - * \begin{align} - * L &= 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} - * + \bar{y} / \hat{y}||^2 \\ - * &= 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 - * \end{align} - * $$ - * </blockquote> - * - * where $w_i^\prime$ is the effective coefficients defined by $w_i/\hat{x_i}$, offset is - * - * <blockquote> - * $$ - * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. - * $$ - * </blockquote> - * - * and diff is - * - * <blockquote> - * $$ - * \sum_i w_i^\prime x_i - y / \hat{y} + offset - * $$ - * </blockquote> - * - * Note that the effective coefficients and offset don't depend on training dataset, - * so they can be precomputed. - * - * Now, the first derivative of the objective function in scaled space is - * - * <blockquote> - * $$ - * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} - * $$ - * </blockquote> - * - * However, $(x_i - \bar{x_i})$ will densify the computation, so it's not - * an ideal formula when the training dataset is sparse format. - * - * This can be addressed by adding the dense $\bar{x_i} / \hat{x_i}$ terms - * in the end by keeping the sum of diff. 
The first derivative of total - * objective function from all the samples is - * - * - * <blockquote> - * $$ - * \begin{align} - * \frac{\partial L}{\partial w_i} &= - * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} \\ - * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) \\ - * &= 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) - * \end{align} - * $$ - * </blockquote> - * - * where $correction_i = - diffSum \bar{x_i} / \hat{x_i}$ - * - * A simple math can show that diffSum is actually zero, so we don't even - * need to add the correction terms in the end. From the definition of diff, - * - * <blockquote> - * $$ - * \begin{align} - * diffSum &= \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) - * / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) \\ - * &= N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) \\ - * &= 0 - * \end{align} - * $$ - * </blockquote> - * - * As a result, the first derivative of the total objective function only depends on - * the training dataset, which can be easily computed in distributed fashion, and is - * sparse format friendly. - * - * <blockquote> - * $$ - * \frac{\partial L}{\partial w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - * $$ - * </blockquote> - * - * @param bcCoefficients The broadcast coefficients corresponding to the features. - * @param labelStd The standard deviation value of the label. - * @param labelMean The mean value of the label. - * @param fitIntercept Whether to fit an intercept term. - * @param bcFeaturesStd The broadcast standard deviation values of the features. - * @param bcFeaturesMean The broadcast mean values of the features. - */ -private class LeastSquaresAggregator( - bcCoefficients: Broadcast[Vector], - labelStd: Double, - labelMean: Double, - fitIntercept: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - bcFeaturesMean: Broadcast[Array[Double]]) extends Serializable { - - private var totalCnt: Long = 0L - private var weightSum: Double = 0.0 - private var lossSum = 0.0 - - private val dim = bcCoefficients.value.size - // make transient so we do not serialize between aggregation stages - @transient private lazy val featuresStd = bcFeaturesStd.value - @transient private lazy val effectiveCoefAndOffset = { - val coefficientsArray = bcCoefficients.value.toArray.clone() - val featuresMean = bcFeaturesMean.value - var sum = 0.0 - var i = 0 - val len = coefficientsArray.length - while (i < len) { - if (featuresStd(i) != 0.0) { - coefficientsArray(i) /= featuresStd(i) - sum += coefficientsArray(i) * featuresMean(i) - } else { - coefficientsArray(i) = 0.0 - } - i += 1 - } - val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 - (Vectors.dense(coefficientsArray), offset) - } - // do not use tuple assignment above because it will circumvent the @transient tag - @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 - @transient private lazy val offset = effectiveCoefAndOffset._2 - - private lazy val gradientSumArray = Array.ofDim[Double](dim) - - /** - * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient - * of the objective function. - * - * @param instance The instance of data point to be added. - * @return This LeastSquaresAggregator object. 
- */ - def add(instance: Instance): this.type = { - instance match { case Instance(label, weight, features) => - - if (weight == 0.0) return this - - val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset - - if (diff != 0) { - val localGradientSumArray = gradientSumArray - val localFeaturesStd = featuresStd - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += weight * diff * value / localFeaturesStd(index) - } - } - lossSum += weight * diff * diff / 2.0 - } - - totalCnt += 1 - weightSum += weight - this - } - } - - /** - * Merge another LeastSquaresAggregator, and update the loss and gradient - * of the objective function. - * (Note that it's in place merging; as a result, `this` object will be modified.) - * - * @param other The other LeastSquaresAggregator to be merged. - * @return This LeastSquaresAggregator object. - */ - def merge(other: LeastSquaresAggregator): this.type = { - - if (other.weightSum != 0) { - totalCnt += other.totalCnt - weightSum += other.weightSum - lossSum += other.lossSum - - var i = 0 - val localThisGradientSumArray = this.gradientSumArray - val localOtherGradientSumArray = other.gradientSumArray - while (i < dim) { - localThisGradientSumArray(i) += localOtherGradientSumArray(i) - i += 1 - } - } - this - } - - def count: Long = totalCnt - - def loss: Double = { - require(weightSum > 0.0, s"The effective number of instances should be " + - s"greater than 0.0, but $weightSum.") - lossSum / weightSum - } - - def gradient: Vector = { - require(weightSum > 0.0, s"The effective number of instances should be " + - s"greater than 0.0, but $weightSum.") - val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / weightSum, result) - result - } -} - -/** - * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. - * It returns the loss and gradient with L2 regularization at a particular point (coefficients). - * It's used in Breeze's convex optimization routines. - */ -private class LeastSquaresCostFun( - instances: RDD[Instance], - labelStd: Double, - labelMean: Double, - fitIntercept: Boolean, - standardization: Boolean, - bcFeaturesStd: Broadcast[Array[Double]], - bcFeaturesMean: Broadcast[Array[Double]], - effectiveL2regParam: Double, - aggregationDepth: Int) extends DiffFunction[BDV[Double]] { - - override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { - val coeffs = Vectors.fromBreeze(coefficients) - val bcCoeffs = instances.context.broadcast(coeffs) - val localFeaturesStd = bcFeaturesStd.value - - val leastSquaresAggregator = { - val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance) - val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2) - - instances.treeAggregate( - new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd, - bcFeaturesMean))(seqOp, combOp, aggregationDepth) - } - - val totalGradientArray = leastSquaresAggregator.gradient.toArray - bcCoeffs.destroy(blocking = false) - - val regVal = if (effectiveL2regParam == 0.0) { - 0.0 - } else { - var sum = 0.0 - coeffs.foreachActive { (index, value) => - // The following code will compute the loss of the regularization; also - // the gradient of the regularization, and add back to totalGradientArray. 
- sum += { - if (standardization) { - totalGradientArray(index) += effectiveL2regParam * value - value * value - } else { - if (localFeaturesStd(index) != 0.0) { - // If `standardization` is false, we still standardize the data - // to improve the rate of convergence; as a result, we have to - // perform this reverse standardization by penalizing each component - // differently to get effectively the same objective function when - // the training dataset is not standardized. - val temp = value / (localFeaturesStd(index) * localFeaturesStd(index)) - totalGradientArray(index) += effectiveL2regParam * temp - value * temp - } else { - 0.0 - } - } - } - } - 0.5 * effectiveL2regParam * sum - } - - (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray)) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala new file mode 100644 index 0000000000..7a4faeb1c1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregatorSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ + +class DifferentiableLossAggregatorSuite extends SparkFunSuite { + + import DifferentiableLossAggregatorSuite.TestAggregator + + private val instances1 = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + private val instances2 = Seq( + Instance(0.2, 0.4, Vectors.dense(0.8, 2.5)), + Instance(0.8, 0.9, Vectors.dense(2.0, 1.3)), + Instance(1.5, 0.2, Vectors.dense(3.0, 0.2)) + ) + + private def assertEqual[T, Agg <: DifferentiableLossAggregator[T, Agg]]( + agg1: DifferentiableLossAggregator[T, Agg], + agg2: DifferentiableLossAggregator[T, Agg]): Unit = { + assert(agg1.weight === agg2.weight) + assert(agg1.loss === agg2.loss) + assert(agg1.gradient === agg2.gradient) + } + + test("empty aggregator") { + val numFeatures = 5 + val coef = Vectors.dense(Array.fill(numFeatures)(1.0)) + val agg = new TestAggregator(numFeatures)(coef) + withClue("cannot get loss for empty aggregator") { + intercept[IllegalArgumentException] { + agg.loss + } + } + withClue("cannot get gradient for empty aggregator") { + intercept[IllegalArgumentException] { + agg.gradient + } + } + } + + test("aggregator initialization") { + val numFeatures = 3 + val coef = Vectors.dense(Array.fill(numFeatures)(1.0)) + val agg = new TestAggregator(numFeatures)(coef) + agg.add(Instance(1.0, 0.3, Vectors.dense(Array.fill(numFeatures)(1.0)))) + assert(agg.gradient.size === 3) + assert(agg.weight === 0.3) + } + + test("merge aggregators") { + val coefficients = Vectors.dense(0.5, -0.1) + val agg1 = new TestAggregator(2)(coefficients) + val agg2 = new TestAggregator(2)(coefficients) + val aggBadDim = new TestAggregator(1)(Vectors.dense(0.5)) + aggBadDim.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + instances1.foreach(agg1.add) + + // merge incompatible aggregators + withClue("cannot merge aggregators with different dimensions") { + intercept[IllegalArgumentException] { + agg1.merge(aggBadDim) + } + } + + // merge empty other + val mergedEmptyOther = agg1.merge(agg2) + assertEqual(mergedEmptyOther, agg1) + assert(mergedEmptyOther === agg1) + + // merge empty this + val agg3 = new TestAggregator(2)(coefficients) + val mergedEmptyThis = agg3.merge(agg1) + assertEqual(mergedEmptyThis, agg1) + assert(mergedEmptyThis !== agg1) + + instances2.foreach(agg2.add) + val (loss1, weight1, grad1) = (agg1.loss, agg1.weight, agg1.gradient) + val (loss2, weight2, grad2) = (agg2.loss, agg2.weight, agg2.gradient) + val merged = agg1.merge(agg2) + + // check pointers are equal + assert(merged === agg1) + + // loss should be weighted average of the two individual losses + assert(merged.loss === (loss1 * weight1 + loss2 * weight2) / (weight1 + weight2)) + assert(merged.weight === weight1 + weight2) + + // gradient should be weighted average of individual gradients + val addedGradients = Vectors.dense(grad1.toArray.clone()) + BLAS.scal(weight1, addedGradients) + BLAS.axpy(weight2, grad2, addedGradients) + BLAS.scal(1 / (weight1 + weight2), addedGradients) + assert(merged.gradient === addedGradients) + } + + test("loss, gradient, weight") { + val coefficients = Vectors.dense(0.5, -0.1) + val agg = new TestAggregator(2)(coefficients) + instances1.foreach(agg.add) + val errors = instances1.map { case Instance(label, _, 
features) => + label - BLAS.dot(features, coefficients) + } + val expectedLoss = errors.zip(instances1).map { case (error: Double, instance: Instance) => + instance.weight * error * error / 2.0 + } + val expectedGradient = Vectors.dense(0.0, 0.0) + errors.zip(instances1).foreach { case (error, instance) => + BLAS.axpy(instance.weight * error, instance.features, expectedGradient) + } + BLAS.scal(1.0 / agg.weight, expectedGradient) + val weightSum = instances1.map(_.weight).sum + + assert(agg.weight ~== weightSum relTol 1e-5) + assert(agg.loss ~== expectedLoss.sum / weightSum relTol 1e-5) + assert(agg.gradient ~== expectedGradient relTol 1e-5) + } +} + +object DifferentiableLossAggregatorSuite { + /** + * Dummy aggregator that represents least squares cost with no intercept. + */ + class TestAggregator(numFeatures: Int)(coefficients: Vector) + extends DifferentiableLossAggregator[Instance, TestAggregator] { + + protected override val dim: Int = numFeatures + + override def add(instance: Instance): TestAggregator = { + val error = instance.label - BLAS.dot(coefficients, instance.features) + weightSum += instance.weight + lossSum += instance.weight * error * error / 2.0 + (0 until dim).foreach { j => + gradientSumArray(j) += instance.weight * error * instance.features(j) + } + this + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala new file mode 100644 index 0000000000..d1cb0d380e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LeastSquaresAggregatorSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.optim.aggregator + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class LeastSquaresAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: Array[Instance] = _ + @transient var instancesConstantFeature: Array[Instance] = _ + @transient var instancesConstantLabel: Array[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + instancesConstantFeature = Array( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)), + Instance(2.0, 0.3, Vectors.dense(1.0, 0.5)) + ) + instancesConstantLabel = Array( + Instance(1.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(1.0, 0.3, Vectors.dense(4.0, 0.5)) + ) + } + + /** Get feature and label summarizers for provided data. */ + def getSummarizers( + instances: Array[Instance]): (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.aggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer + )(seqOp, combOp) + } + + /** Get summary statistics for some data and create a new LeastSquaresAggregator. 
+   */
+  def getNewAggregator(
+      instances: Array[Instance],
+      coefficients: Vector,
+      fitIntercept: Boolean): LeastSquaresAggregator = {
+    val (featuresSummarizer, ySummarizer) = getSummarizers(instances)
+    val yStd = math.sqrt(ySummarizer.variance(0))
+    val yMean = ySummarizer.mean(0)
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd)
+    val featuresMean = featuresSummarizer.mean
+    val bcFeaturesMean = spark.sparkContext.broadcast(featuresMean.toArray)
+    val bcCoefficients = spark.sparkContext.broadcast(coefficients)
+    new LeastSquaresAggregator(yStd, yMean, fitIntercept, bcFeaturesStd,
+      bcFeaturesMean)(bcCoefficients)
+  }
+
+  test("check sizes") {
+    val coefficients = Vectors.dense(1.0, 2.0)
+    val aggIntercept = getNewAggregator(instances, coefficients, fitIntercept = true)
+    val aggNoIntercept = getNewAggregator(instances, coefficients, fitIntercept = false)
+    instances.foreach(aggIntercept.add)
+    instances.foreach(aggNoIntercept.add)
+
+    // least squares agg does not include intercept in its gradient array
+    assert(aggIntercept.gradient.size === 2)
+    assert(aggNoIntercept.gradient.size === 2)
+  }
+
+  test("check correctness") {
+    /*
+      Check that the aggregator computes loss/gradient for:
+      0.5 * sum_i=1^N ([sum_j=1^D beta_j * ((x_j - x_j,bar) / sigma_j)] - ((y - ybar) / sigma_y))^2
+     */
+    val coefficients = Vectors.dense(1.0, 2.0)
+    val numFeatures = coefficients.size
+    val (featuresSummarizer, ySummarizer) = getSummarizers(instances)
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+    val featuresMean = featuresSummarizer.mean.toArray
+    val yStd = math.sqrt(ySummarizer.variance(0))
+    val yMean = ySummarizer.mean(0)
+
+    val agg = getNewAggregator(instances, coefficients, fitIntercept = true)
+    instances.foreach(agg.add)
+
+    // compute the errors (prediction - label) analytically
+    val errors = instances.map { case Instance(l, w, f) =>
+      val scaledFeatures = (0 until numFeatures).map { j =>
+        (f.toArray(j) - featuresMean(j)) / featuresStd(j)
+      }.toArray
+      val scaledLabel = (l - yMean) / yStd
+      BLAS.dot(coefficients, Vectors.dense(scaledFeatures)) - scaledLabel
+    }
+
+    // compute expected loss sum analytically
+    val expectedLoss = errors.zip(instances).map { case (error, instance) =>
+      instance.weight * error * error / 2.0
+    }
+
+    // compute gradient analytically from instances
+    val expectedGradient = Vectors.dense(0.0, 0.0)
+    errors.zip(instances).foreach { case (error, instance) =>
+      val scaledFeatures = (0 until numFeatures).map { j =>
+        instance.weight * instance.features.toArray(j) / featuresStd(j)
+      }.toArray
+      BLAS.axpy(error, Vectors.dense(scaledFeatures), expectedGradient)
+    }
+
+    val weightSum = instances.map(_.weight).sum
+    BLAS.scal(1.0 / weightSum, expectedGradient)
+    assert(agg.loss ~== (expectedLoss.sum / weightSum) relTol 1e-5)
+    assert(agg.gradient ~== expectedGradient relTol 1e-5)
+  }
+
+  test("check with zero standard deviation") {
+    val coefficients = Vectors.dense(1.0, 2.0)
+    val aggConstantFeature = getNewAggregator(instancesConstantFeature, coefficients,
+      fitIntercept = true)
+    instancesConstantFeature.foreach(aggConstantFeature.add)
+    // constant features should not affect gradient
+    assert(aggConstantFeature.gradient(0) === 0.0)
+
+    withClue("LeastSquaresAggregator does not support zero standard deviation of the label") {
+      intercept[IllegalArgumentException] {
+        getNewAggregator(instancesConstantLabel, coefficients, fitIntercept = true)
+      }
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala
new file mode 100644
index 0000000000..0794417a8d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.loss
+
+import org.apache.spark.SparkFunSuite
+
+class DifferentiableRegularizationSuite extends SparkFunSuite {
+
+  test("L2 regularization") {
+    val shouldApply = (_: Int) => true
+    val regParam = 0.3
+    val coefficients = Array(1.0, 3.0, -2.0)
+    val numFeatures = coefficients.size
+
+    // check without feature standard deviations
+    val regFun = new L2Regularization(regParam, shouldApply, None)
+    val (loss, grad) = regFun.calculate(coefficients)
+    assert(loss === 0.5 * regParam * coefficients.map(x => x * x).sum)
+    assert(grad === coefficients.map(_ * regParam))
+
+    // check with feature standard deviations
+    val featuresStd = Array(0.1, 1.1, 0.5)
+    val regFunStd = new L2Regularization(regParam, shouldApply, Some(featuresStd))
+    val (lossStd, gradStd) = regFunStd.calculate(coefficients)
+    val expectedLossStd = 0.5 * regParam * (0 until numFeatures).map { j =>
+      coefficients(j) * coefficients(j) / (featuresStd(j) * featuresStd(j))
+    }.sum
+    val expectedGradientStd = (0 until numFeatures).map { j =>
+      regParam * coefficients(j) / (featuresStd(j) * featuresStd(j))
+    }.toArray
+    assert(lossStd === expectedLossStd)
+    assert(gradStd === expectedGradientStd)
+
+    // check the shouldApply filter
+    val shouldApply2 = (i: Int) => i == 1
+    val regFunApply = new L2Regularization(regParam, shouldApply2, None)
+    val (lossApply, gradApply) = regFunApply.calculate(coefficients)
+    assert(lossApply === 0.5 * regParam * coefficients(1) * coefficients(1))
+    assert(gradApply === Array(0.0, coefficients(1) * regParam, 0.0))
+
+    // check with a zero feature standard deviation
+    val featuresStdZero = Array(0.1, 0.0, 0.5)
+    val regFunStdZero = new L2Regularization(regParam, shouldApply, Some(featuresStdZero))
+    val (_, gradStdZero) = regFunStdZero.calculate(coefficients)
+    assert(gradStdZero(1) == 0.0)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala
new file mode 100644
index 0000000000..cd5cebee5f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/loss/RDDLossFunctionSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.ml.optim.loss + +import org.apache.spark.SparkFunSuite +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregatorSuite.TestAggregator +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD + +class RDDLossFunctionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var instances: RDD[Instance] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + instances = sc.parallelize(Seq( + Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)), + Instance(2.0, 0.3, Vectors.dense(4.0, 0.5)) + )) + } + + test("regularization") { + val coefficients = Vectors.dense(0.5, -0.1) + val regLossFun = new L2Regularization(0.1, (_: Int) => true, None) + val getAgg = (bvec: Broadcast[Vector]) => new TestAggregator(2)(bvec.value) + val lossNoReg = new RDDLossFunction(instances, getAgg, None) + val lossWithReg = new RDDLossFunction(instances, getAgg, Some(regLossFun)) + + val (loss1, grad1) = lossNoReg.calculate(coefficients.asBreeze.toDenseVector) + val (regLoss, regGrad) = regLossFun.calculate(coefficients.toArray) + val (loss2, grad2) = lossWithReg.calculate(coefficients.asBreeze.toDenseVector) + + BLAS.axpy(1.0, Vectors.fromBreeze(grad1), Vectors.dense(regGrad)) + assert(Vectors.dense(regGrad) ~== Vectors.fromBreeze(grad2) relTol 1e-5) + assert(loss1 + regLoss === loss2) + } + + test("empty RDD") { + val rdd = sc.parallelize(Seq.empty[Instance]) + val coefficients = Vectors.dense(0.5, -0.1) + val getAgg = (bv: Broadcast[Vector]) => new TestAggregator(2)(bv.value) + val lossFun = new RDDLossFunction(rdd, getAgg, None) + withClue("cannot calculate cost for empty dataset") { + intercept[IllegalArgumentException]{ + lossFun.calculate(coefficients.asBreeze.toDenseVector) + } + } + } + + test("versus aggregating on an iterable") { + val coefficients = Vectors.dense(0.5, -0.1) + val getAgg = (bv: Broadcast[Vector]) => new TestAggregator(2)(bv.value) + val lossFun = new RDDLossFunction(instances, getAgg, None) + val (loss, grad) = lossFun.calculate(coefficients.asBreeze.toDenseVector) + + // just map the aggregator over the instances array + val agg = new TestAggregator(2)(coefficients) + instances.collect().foreach(agg.add) + + assert(loss === agg.loss) + assert(Vectors.fromBreeze(grad) === agg.gradient) + } + +} -- GitLab