Skip to content
Snippets Groups Projects
Commit 4d97be95 authored by Holden Karau's avatar Holden Karau Committed by DB Tsai
Browse files

[SPARK-9204][ML] Add default params test for linearyregression suite

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7553 from holdenk/SPARK-9204-add-default-params-test-to-linear-regression and squashes the following commits:

630ba19 [Holden Karau] style fix
faa08a3 [Holden Karau] Add default params test for linearyregression suite
parent a3c7a3ce
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
package org.apache.spark.ml.regression package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.TestingUtils._
...@@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
test("params") {
ParamsSuite.checkParams(new LinearRegression)
val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0)
ParamsSuite.checkParams(model)
}
test("linear regression: default params") {
val lir = new LinearRegression
assert(lir.getLabelCol === "label")
assert(lir.getFeaturesCol === "features")
assert(lir.getPredictionCol === "prediction")
assert(lir.getRegParam === 0.0)
assert(lir.getElasticNetParam === 0.0)
assert(lir.getFitIntercept)
val model = lir.fit(dataset)
model.transform(dataset)
.select("label", "prediction")
.collect()
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
assert(model.intercept !== 0.0)
assert(model.hasParent)
}
test("linear regression with intercept without regularization") { test("linear regression with intercept without regularization") {
val trainer = new LinearRegression val trainer = new LinearRegression
val model = trainer.fit(dataset) val model = trainer.fit(dataset)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment