diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index 143bf539b0afebf1b8a250cba5d38e1bd797bb23..9c495512422bacb0b8418a35515e29df2ef1583a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -89,7 +89,7 @@ private[ml] class IterativelyReweightedLeastSquares( val oldCoefficients = oldModel.coefficients val coefficients = model.coefficients BLAS.axpy(-1.0, coefficients, oldCoefficients) - val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => + val maxTolOfCoefficients = oldCoefficients.toArray.foldLeft(0.0) { (x, y) => math.max(math.abs(x), math.abs(y)) } val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index c4f41d07800c57c7792e6b42b4a1f5984da28607..fdeadaf2749719ca58f442a8993fe0327776749a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -335,6 +335,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val throw new SparkException(msg) } + require(numFeatures > 0 || $(fitIntercept), + "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + + "set to false. To fit a model with 0 features, fitIntercept must be set to true." ) + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ea059858a58b75fd2396603d49527007efaaaad4..add28a72b6808fe3f3d22c418f1a125ec1edcb65 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -743,6 +743,61 @@ class GeneralizedLinearRegressionSuite } } + test("generalized linear regression: intercept only") { + /* + R code: + + library(statmod) + y <- c(1.0, 0.5, 0.7, 0.3) + w <- c(1, 2, 3, 4) + for (fam in list(gaussian(), poisson(), binomial(), Gamma(), tweedie(1.6))) { + model1 <- glm(y ~ 1, family = fam) + model2 <- glm(y ~ 1, family = fam, weights = w) + print(as.vector(c(coef(model1), coef(model2)))) + } + [1] 0.625 0.530 + [1] -0.4700036 -0.6348783 + [1] 0.5108256 0.1201443 + [1] 1.600000 1.886792 + [1] 1.325782 1.463641 + */ + + val dataset = Seq( + Instance(1.0, 1.0, Vectors.zeros(0)), + Instance(0.5, 2.0, Vectors.zeros(0)), + Instance(0.7, 3.0, Vectors.zeros(0)), + Instance(0.3, 4.0, Vectors.zeros(0)) + ).toDF() + + val expected = Seq(0.625, 0.530, -0.4700036, -0.6348783, 0.5108256, 0.1201443, + 1.600000, 1.886792, 1.325782, 1.463641) + + import GeneralizedLinearRegression._ + + var idx = 0 + for (family <- Seq("gaussian", "poisson", "binomial", "gamma", "tweedie")) { + for (useWeight <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression().setFamily(family) + if (useWeight) trainer.setWeightCol("weight") + if (family == "tweedie") trainer.setVariancePower(1.6) + val model = trainer.fit(dataset) + val actual = model.intercept + assert(actual ~== expected(idx) absTol 1E-3, "Model mismatch: intercept only GLM with " + + s"useWeight = $useWeight and family = $family.") + assert(model.coefficients === new DenseVector(Array.empty[Double])) + idx += 1 + } + } + + // throw exception for empty model + val trainer = new GeneralizedLinearRegression().setFitIntercept(false) + withClue("Specified model is empty with neither intercept nor feature") { + intercept[IllegalArgumentException] { + trainer.fit(dataset) + } + } + } + test("glm summary: gaussian family with weight") { /* R code: