diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 51ede15d6c3671fd6b6d4d807e5a422085709fd0..9469acf62e13d795e25aac3f1a95c537eda0912f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -937,50 +937,47 @@ class BinaryLogisticRegressionSummary private[classification] ( * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. - * @param featuresStd The standard deviation values of the features. - * @param featuresMean The mean values of the features. */ private class LogisticAggregator( - coefficients: Vector, + private val numFeatures: Int, numClasses: Int, - fitIntercept: Boolean, - featuresStd: Array[Double], - featuresMean: Array[Double]) extends Serializable { + fitIntercept: Boolean) extends Serializable { private var weightSum = 0.0 private var lossSum = 0.0 - private val coefficientsArray = coefficients match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") - } - - private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - - private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) + private val gradientSumArray = + Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures) /** * Add a new training instance to this LogisticAggregator, and update the loss and gradient * of the objective function. * * @param instance The instance of data point to be added. + * @param coefficients The coefficients corresponding to the features. + * @param featuresStd The standard deviation values of the features. * @return This LogisticAggregator object. */ - def add(instance: Instance): this.type = { + def add( + instance: Instance, + coefficients: Vector, + featuresStd: Array[Double]): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $dim but got ${features.size}.") + require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + + s" Expecting $numFeatures but got ${features.size}.") require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this - val localCoefficientsArray = coefficientsArray + val coefficientsArray = coefficients match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + } val localGradientSumArray = gradientSumArray numClasses match { @@ -990,11 +987,11 @@ private class LogisticAggregator( var sum = 0.0 features.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localCoefficientsArray(index) * (value / featuresStd(index)) + sum += coefficientsArray(index) * (value / featuresStd(index)) } } sum + { - if (fitIntercept) localCoefficientsArray(dim) else 0.0 + if (fitIntercept) coefficientsArray(numFeatures) else 0.0 } } @@ -1007,7 +1004,7 @@ private class LogisticAggregator( } if (fitIntercept) { - localGradientSumArray(dim) += multiplier + localGradientSumArray(numFeatures) += multiplier } if (label > 0) { @@ -1034,8 +1031,8 @@ private class LogisticAggregator( * @return This LogisticAggregator object. */ def merge(other: LogisticAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $numFeatures but got ${other.numFeatures}.") if (other.weightSum != 0.0) { weightSum += other.weightSum @@ -1086,13 +1083,17 @@ private class LogisticCostFun( override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length val coeffs = Vectors.fromBreeze(coefficients) + val n = coeffs.size + val localFeaturesStd = featuresStd + val logisticAggregator = { - val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) + val seqOp = (c: LogisticAggregator, instance: Instance) => + c.add(instance, coeffs, localFeaturesStd) val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) instances.treeAggregate( - new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean) + new LogisticAggregator(numFeatures, numClasses, fitIntercept) )(seqOp, combOp) }