Skip to content
Snippets Groups Projects
Commit 7e3a1ada authored by coderxiang's avatar coderxiang Committed by Xiangrui Meng
Browse files

[MLlib] SPARK-3987: add test case on objective value for NNLS

Also update step parameter to pass the proposed test

Author: coderxiang <shuoxiangpub@gmail.com>

Closes #2965 from coderxiang/nnls-test and squashes the following commits:

24b06f9 [coderxiang] add test case on objective value for NNLS; update step parameter to pass the test
parent bfa614b1
No related branches found
No related tags found
No related merge requests found
......@@ -79,7 +79,7 @@ private[mllib] object NNLS {
// stopping condition
def stop(step: Double, ndir: Double, nx: Double): Boolean = {
((step.isNaN) // NaN
|| (step < 1e-6) // too small or negative
|| (step < 1e-7) // too small or negative
|| (step > 1e40) // too small; almost certainly numerical problems
|| (ndir < 1e-12 * nx) // gradient relatively too small
|| (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk
......
......@@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite {
(ata, atb)
}
/** Compute the objective value */
def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = {
val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x))
res.get(0)
}
test("NNLS: exact solution cases") {
val n = 20
val rand = new Random(12346)
......@@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite {
assert(x(i) >= 0)
}
}
test("NNLS: objective value test") {
val n = 5
val ata = new DoubleMatrix(5, 5
, 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283
, 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884
, -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049
, 130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819
, -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814
)
val atb = new DoubleMatrix(5, 1,
-31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017)
/** reference solution obtained from matlab function quadprog */
val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627))
val refObj = computeObjectiveValue(ata, atb, refx)
val ws = NNLS.createWorkspace(n)
val x = new DoubleMatrix(NNLS.solve(ata, atb, ws))
val obj = computeObjectiveValue(ata, atb, x)
assert(obj < refObj + 1E-5)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment