Commit a19a1bb5 authored by Yanbo Liang

[SPARK-16356][FOLLOW-UP][ML] Enforce ML test of exception for local/distributed Dataset.

## What changes were proposed in this pull request?
#14035 added ```testImplicits``` to ML unit tests and promoted ```toDF()```, but left one minor issue in ```VectorIndexerSuite```: if we create the DataFrame with ```Seq(...).toDF()```, one of the test cases throws a different exception than it does with ```sc.parallelize(Seq(...)).toDF()```.
After an in-depth study, I found it was caused by the different behavior of local and distributed Datasets when a UDF fails at ```assert```. If the data is a local Dataset, it throws ```AssertionError``` directly; if the data is a distributed Dataset, it throws ```SparkException```, which wraps the ```AssertionError```. I think we should enforce this test to cover both cases.
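For illustration, the difference can be reproduced outside the suite with a UDF that fails an ```assert```. The following is a minimal sketch, assuming a plain ```SparkSession```; the object and UDF names (```DemoLocalVsDistributed```, ```failing```) are hypothetical and not part of this patch, and the behavior described is as observed on the Spark 2.x line this change targets:

```scala
// Hypothetical sketch, not part of this patch: demonstrates the two failure
// modes of a UDF whose internal assertion fails.
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object DemoLocalVsDistributed {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    import spark.implicits._

    // A UDF whose assertion fails for every row.
    val failing = udf((x: Int) => { assert(x < 0, "boom"); x })

    val local = Seq(1, 2, 3).toDF("x")

    // Local Dataset: the optimizer can evaluate the projection over the
    // LocalRelation on the driver, so the AssertionError surfaces directly.
    try local.select(failing($"x")).collect()
    catch { case e: AssertionError => println(s"local: $e") }

    // Distributed Dataset: the assertion fails inside an executor task and is
    // rethrown on the driver wrapped in a SparkException.
    try local.repartition(2).select(failing($"x")).collect()
    catch { case e: SparkException => println(s"distributed: $e") }

    spark.stop()
  }
}
```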

## How was this patch tested?
Unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15261 from yanboliang/spark-16356.
parent 37eb9184
```diff
@@ -88,9 +88,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
     densePoints1 = densePoints1Seq.map(FeatureData).toDF()
     sparsePoints1 = sparsePoints1Seq.map(FeatureData).toDF()
-    // TODO: If we directly use `toDF` without parallelize, the test in
-    // "Throws error when given RDDs with different size vectors" is failed for an unknown reason.
-    densePoints2 = sc.parallelize(densePoints2Seq, 2).map(FeatureData).toDF()
+    densePoints2 = densePoints2Seq.map(FeatureData).toDF()
     sparsePoints2 = sparsePoints2Seq.map(FeatureData).toDF()
     badPoints = badPointsSeq.map(FeatureData).toDF()
   }
@@ -121,10 +119,17 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
     model.transform(densePoints1) // should work
     model.transform(sparsePoints1) // should work
-    intercept[SparkException] {
+    // If the data is local Dataset, it throws AssertionError directly.
+    intercept[AssertionError] {
       model.transform(densePoints2).collect()
       logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
     }
+    // If the data is distributed Dataset, it throws SparkException
+    // which is the wrapper of AssertionError.
+    intercept[SparkException] {
+      model.transform(densePoints2.repartition(2)).collect()
+      logInfo("Did not throw error when fit, transform were called on vectors of different lengths")
+    }
     intercept[SparkException] {
       vectorIndexer.fit(badPoints)
       logInfo("Did not throw error when fitting vectors of different lengths in same RDD.")
```