From 4d97be95300f729391c17b4c162e3c7fba09b8bf Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@pigscanfly.ca>
Date: Mon, 20 Jul 2015 22:15:10 -0700
Subject: [PATCH] [SPARK-9204][ML] Add default params test for
 linearyregression suite

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7553 from holdenk/SPARK-9204-add-default-params-test-to-linear-regression and squashes the following commits:

630ba19 [Holden Karau] style fix
faa08a3 [Holden Karau] Add default params test for linearyregression suite
---
 .../ml/regression/LinearRegressionSuite.scala | 25 +++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 374002c5b4..7cdda3db88 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
@@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   }
 
+  test("params") {
+    ParamsSuite.checkParams(new LinearRegression)
+    val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
+    ParamsSuite.checkParams(model)
+  }
+
+  test("linear regression: default params") {
+    val lir = new LinearRegression
+    assert(lir.getLabelCol === "label")
+    assert(lir.getFeaturesCol === "features")
+    assert(lir.getPredictionCol === "prediction")
+    assert(lir.getRegParam === 0.0)
+    assert(lir.getElasticNetParam === 0.0)
+    assert(lir.getFitIntercept)
+    val model = lir.fit(dataset)
+    model.transform(dataset)
+      .select("label", "prediction")
+      .collect()
+    assert(model.getFeaturesCol === "features")
+    assert(model.getPredictionCol === "prediction")
+    assert(model.intercept !== 0.0)
+    assert(model.hasParent)
+  }
+
   test("linear regression with intercept without regularization") {
     val trainer = new LinearRegression
     val model = trainer.fit(dataset)
-- 
GitLab