Skip to content
Snippets Groups Projects
Commit e29ce319 authored by Holden Karau, committed by DB Tsai
Browse files

[SPARK-8963][ML] cleanup tests in linear regression suite

Simplify model weight assertions to use vector comparison, and switch to absTol when comparing against 0.0 intercepts (a relative tolerance is meaningless when the expected value is zero)

Author: Holden Karau <holden@pigscanfly.ca>

Closes #7327 from holdenk/SPARK-8913-cleanup-tests-from-SPARK-8700-logistic-regression and squashes the following commits:

5bac185 [Holden Karau] Simplify model weight assertions to use vector comparison, switch to using absTol when comparing with 0.0 intercepts
parent 69165330
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
package org.apache.spark.ml.regression package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.{DataFrame, Row}
...@@ -75,11 +75,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -75,11 +75,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V3. 7.198257 as.numeric.data.V3. 7.198257
*/ */
val interceptR = 6.298698 val interceptR = 6.298698
val weightsR = Array(4.700706, 7.199082) val weightsR = Vectors.dense(4.700706, 7.199082)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~= weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
...@@ -104,11 +103,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -104,11 +103,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V2. 6.995908 as.numeric.data.V2. 6.995908
as.numeric.data.V3. 5.275131 as.numeric.data.V3. 5.275131
*/ */
val weightsR = Array(6.995908, 5.275131) val weightsR = Vectors.dense(6.995908, 5.275131)
assert(model.intercept ~== 0 relTol 1E-3) assert(model.intercept ~== 0 absTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~= weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
/* /*
Then again with the data with no intercept: Then again with the data with no intercept:
> weightsWithoutIntercept > weightsWithoutIntercept
...@@ -118,11 +116,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -118,11 +116,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data3.V2. 4.70011 as.numeric.data3.V2. 4.70011
as.numeric.data3.V3. 7.19943 as.numeric.data3.V3. 7.19943
*/ */
val weightsWithoutInterceptR = Array(4.70011, 7.19943) val weightsWithoutInterceptR = Vectors.dense(4.70011, 7.19943)
assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3) assert(modelWithoutIntercept.intercept ~== 0 absTol 1E-3)
assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3) assert(modelWithoutIntercept.weights ~= weightsWithoutInterceptR relTol 1E-3)
assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
} }
test("linear regression with intercept with L1 regularization") { test("linear regression with intercept with L1 regularization") {
...@@ -139,11 +136,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -139,11 +136,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V3. 6.679841 as.numeric.data.V3. 6.679841
*/ */
val interceptR = 6.24300 val interceptR = 6.24300
val weightsR = Array(4.024821, 6.679841) val weightsR = Vectors.dense(4.024821, 6.679841)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~= weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
...@@ -169,11 +165,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -169,11 +165,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V3. 4.772913 as.numeric.data.V3. 4.772913
*/ */
val interceptR = 0.0 val interceptR = 0.0
val weightsR = Array(6.299752, 4.772913) val weightsR = Vectors.dense(6.299752, 4.772913)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR absTol 1E-5)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~= weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
...@@ -197,11 +192,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -197,11 +192,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V3. 4.926260 as.numeric.data.V3. 4.926260
*/ */
val interceptR = 5.269376 val interceptR = 5.269376
val weightsR = Array(3.736216, 5.712356) val weightsR = Vectors.dense(3.736216, 5.712356)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~= weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
...@@ -227,11 +221,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -227,11 +221,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V3. 4.214502 as.numeric.data.V3. 4.214502
*/ */
val interceptR = 0.0 val interceptR = 0.0
val weightsR = Array(5.522875, 4.214502) val weightsR = Vectors.dense(5.522875, 4.214502)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR absTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~== weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
...@@ -255,11 +248,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -255,11 +248,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.data.V3. 5.200403 as.numeric.data.V3. 5.200403
*/ */
val interceptR = 5.696056 val interceptR = 5.696056
val weightsR = Array(3.670489, 6.001122) val weightsR = Vectors.dense(3.670489, 6.001122)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~== weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
...@@ -285,11 +277,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -285,11 +277,10 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
as.numeric.dataM.V3. 4.322251 as.numeric.dataM.V3. 4.322251
*/ */
val interceptR = 0.0 val interceptR = 0.0
val weightsR = Array(5.673348, 4.322251) val weightsR = Vectors.dense(5.673348, 4.322251)
assert(model.intercept ~== interceptR relTol 1E-3) assert(model.intercept ~== interceptR absTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights ~= weightsR relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach { model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) => case Row(features: DenseVector, prediction1: Double) =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment