Skip to content
Snippets Groups Projects
Commit 92654366 authored by wm624@hotmail.com's avatar wm624@hotmail.com Committed by Joseph K. Bradley
Browse files

[SPARK-19382][ML] Test sparse vectors in LinearSVCSuite

## What changes were proposed in this pull request?

Add unit tests for testing SparseVector.

We can't add mixed DenseVector and SparseVector test case, as discussed in JIRA 19382.

 def merge(other: MultivariateOnlineSummarizer): this.type = {
if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got $
{other.n}

.")

## How was this patch tested?

Unit tests

Author: wm624@hotmail.com <wm624@hotmail.com>
Author: Miao Wang <wangmiao1981@users.noreply.github.com>

Closes #16784 from wangmiao1981/bk.
parent 9991c2da
No related branches found
No related tags found
No related merge requests found
......@@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.udf
class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
......@@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
@transient var smallValidationDataset: Dataset[_] = _
@transient var binaryDataset: Dataset[_] = _
@transient var smallSparseBinaryDataset: Dataset[_] = _
@transient var smallSparseValidationDataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
......@@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF()
smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF()
binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF()
// Dataset for testing SparseVector
val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse
val sparse = udf(toSparse)
smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features))
smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features))
}
/**
......@@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val model = svm.fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
val sparseModel = svm.fit(smallSparseBinaryDataset)
checkModels(model, sparseModel)
}
test("Linear SVC binary classification with regularization") {
......@@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
val sparseModel = svm.fit(smallSparseBinaryDataset)
checkModels(model, sparseModel)
}
test("params") {
......@@ -235,7 +250,7 @@ object LinearSVCSuite {
"aggregationDepth" -> 3
)
// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
def generateSVMInput(
intercept: Double,
weights: Array[Double],
......@@ -252,5 +267,10 @@ object LinearSVCSuite {
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = {
assert(model1.intercept == model2.intercept)
assert(model1.coefficients.equals(model2.coefficients))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment