Skip to content
Snippets Groups Projects
Commit 28dbde38 authored by Yuhao Yang's avatar Yuhao Yang Committed by Sean Owen
Browse files

[SPARK-7983] [MLLIB] Add require for one-based indices in loadLibSVMFile

jira: https://issues.apache.org/jira/browse/SPARK-7983

Customers frequently use zero-based indices in their LIBSVM files. No warnings or errors from Spark will be reported during their computation afterwards, and usually it will lead to wired result for many algorithms (like GBDT).

add a quick check.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #6538 from hhbyyh/loadSVM and squashes the following commits:

79d9c11 [Yuhao Yang] optimization as respond to comments
4310710 [Yuhao Yang] merge conflict
96460f1 [Yuhao Yang] merge conflict
20a2811 [Yuhao Yang] use require
6e4f8ca [Yuhao Yang] add check for ascending order
9956365 [Yuhao Yang] add ut for 0-based loadlibsvm exception
5bd1f9a [Yuhao Yang] add require for one-based in loadLIBSVM
parent d38cf217
No related branches found
No related tags found
No related merge requests found
......@@ -82,6 +82,18 @@ object MLUtils {
val value = indexAndValue(1).toDouble
(index, value)
}.unzip
// check if indices are one-based and in ascending order
var previous = -1
var i = 0
val indicesLength = indices.length
while (i < indicesLength) {
val current = indices(i)
require(current > previous, "indices should be one-based and in ascending order" )
previous = current
i += 1
}
(label, indices.toArray, values.toArray)
}
......
......@@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
......@@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
val lines =
"""
|0
|0 0:4.0 4:5.0 6:6.0
""".stripMargin
val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString
intercept[SparkException] {
loadLibSVMFile(sc, path).collect()
}
Utils.deleteRecursively(tempDir)
}
test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
val lines =
"""
|0
|0 3:4.0 2:5.0 6:6.0
""".stripMargin
val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString
intercept[SparkException] {
loadLibSVMFile(sc, path).collect()
}
Utils.deleteRecursively(tempDir)
}
test("saveAsLibSVMFile") {
val examples = sc.parallelize(Seq(
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment